
Using FP8 for inference without CPU offloading can introduce noise. #10302

Open
todochenxi opened this issue Dec 19, 2024 · 3 comments
Labels
bug Something isn't working

Comments

todochenxi commented Dec 19, 2024

Describe the bug

If I call pipe.enable_model_cpu_offload(device=device), the model produces correct results after a warm-up run. However, if I comment out that line and instead move the whole pipeline to the GPU with .to(device), the inference results are noisy.

Reproduction

from diffusers import (
    FluxPipeline,
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    AutoencoderKL,
)
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5TokenizerFast
from optimum.quanto import freeze, qfloat8, quantize
import torch

dtype = torch.bfloat16
bfl_repo = "black-forest-labs/FLUX.1-dev"
device = "cuda"

scheduler      = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", torch_dtype=dtype)
text_encoder   = CLIPTextModel.from_pretrained(bfl_repo, subfolder="text_encoder", torch_dtype=dtype)
tokenizer      = CLIPTokenizer.from_pretrained(bfl_repo, subfolder="tokenizer", clean_up_tokenization_spaces=True)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer_2    = T5TokenizerFast.from_pretrained(bfl_repo, subfolder="tokenizer_2", clean_up_tokenization_spaces=True)
vae            = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype)

# Load the FP8 single-file checkpoint, then quantize the transformer and the T5 encoder to qfloat8.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=transformer,
).to(device, dtype=dtype)  # edit: move the whole pipeline to the GPU instead of offloading

# pipe.enable_model_cpu_offload(device=device)

# Warm-up pass at a tiny resolution.
params = {
    "prompt": "a cat",
    "num_images_per_prompt": 1,
    "num_inference_steps": 1,
    "width": 64,
    "height": 64,
    "guidance_scale": 7,
}
image = pipe(**params).images[0]  # warmup

# Actual generation: with .to(device) instead of CPU offload, this image comes out as noise.
params = {
    "prompt": "a cat",
    "num_images_per_prompt": 1,
    "num_inference_steps": 25,
    "width": 512,
    "height": 512,
    "guidance_scale": 7,
}
image = pipe(**params).images[0]
image.save("1.jpg")

Logs

No response

System Info

WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:
PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.4.1+cu121)
Python 3.10.15 (you have 3.10.13)
Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
Memory-efficient attention, SwiGLU, sparse and more won't be available.
Set XFORMERS_MORE_DETAILS=1 for more details

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-6.8.0-49-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.2
  • Transformers version: 4.46.2
  • Accelerate version: 0.31.0
  • PEFT version: 0.14.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.3
  • xFormers version: 0.0.28.post3
  • Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
    NVIDIA GeForce RTX 3090, 24576 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @DN6

todochenxi added the bug (Something isn't working) label Dec 19, 2024
hlky (Collaborator) commented Dec 19, 2024

Hi @todochenxi. Could you share an example of the noisy outputs?

todochenxi (Author) replied:

> Hi @todochenxi. Could you share an example of the noisy outputs?

Sure :)

[image attachment: example of the noisy output]


Dylanooo commented Dec 23, 2024

Same result @todochenxi
With optimum-quanto==0.2.6, I got the same result (noisy output).
With optimum-quanto==0.2.4, text-to-image was normal, but I got an error when loading a LoRA:
pipeline.load_lora_weights("AWPortrait_CN_1.0.safetensors", adapter_name=lora_name)
  File "/usr/local/lib/python3.10/dist-packages/diffusers/loaders/lora_pipeline.py", line 1846, in load_lora_weights
    self.load_lora_into_transformer(
  File "/usr/local/lib/python3.10/dist-packages/diffusers/loaders/lora_pipeline.py", line 1949, in load_lora_into_transformer
    incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py", line 241, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/usr/local/lib/python3.10/dist-packages/optimum/quanto/nn/qmodule.py", line 159, in _load_from_state_dict
    deserialized_weight = QBytesTensor.load_from_state_dict(
  File "/usr/local/lib/python3.10/dist-packages/optimum/quanto/tensor/qbytes.py", line 90, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'
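
For context, a minimal sketch of the failing step, assuming the quantized and frozen FLUX pipeline from the reproduction above (lora_name is a placeholder adapter name):

# With optimum-quanto==0.2.4, loading a LoRA onto the qfloat8-quantized transformer fails:
lora_name = "awportrait_cn"  # placeholder
pipe.load_lora_weights("AWPortrait_CN_1.0.safetensors", adapter_name=lora_name)
# KeyError: 'time_text_embed.timestep_embedder.linear_1.weight._data'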
