flux fill fp8 load failed #10244

Open
Suprhimp opened this issue Dec 16, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@Suprhimp

Describe the bug

cc. https://huggingface.co/AlekseyCalvin/FluxFillDev_fp8_Diffusers/discussions/1

I want to run Flux Fill with FP8 weights for faster inference, but loading the model fails.

Reproduction

from diffusers import FluxTransformer2DModel, FluxFillPipeline
from transformers import T5EncoderModel
import torch

transformer = FluxTransformer2DModel.from_pretrained("AlekseyCalvin/FluxFillDev_fp8_Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
text_encoder_2 = T5EncoderModel.from_pretrained("AlekseyCalvin/FluxFillDev_fp8_Diffusers", subfolder="text_encoder_2", torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16).to("cuda")

or, loading a local FP8 single-file checkpoint and quantizing it with optimum-quanto:

import torch
from diffusers import FluxTransformer2DModel, FluxFillPipeline
from diffusers.utils import load_image
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize

bfl_repo = "black-forest-labs/FLUX.1-Fill-dev"
dtype = torch.bfloat16

# Load the FP8 Fill transformer from a local single-file checkpoint, then weight-quantize it with quanto
transformer = FluxTransformer2DModel.from_single_file("/home/ubuntu/black-forest-labs_FLUX.1-Fill-dev_flux1-fill-dev_fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

# Quantize the T5 text encoder the same way
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Inpainting needs FluxFillPipeline; FluxPipeline does not accept image/mask_image
pipe = FluxFillPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.enable_model_cpu_offload()

prompt = "a white paper cup"

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

image = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]

image.save("flux-fp8-dev.png")


Logs

ValueError: Trying to set a tensor of shape torch.Size([3072, 384]) in "weight" (which has shape torch.Size([3072, 64])), this looks incorrect.

System Info

diffusers 0.32.0.dev0, Python 3.10, NVIDIA A100

Who can help?

@sayakpaul @DN6

@Suprhimp Suprhimp added the bug Something isn't working label Dec 16, 2024
@hlky
Collaborator

hlky commented Dec 16, 2024

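Looks like AlekseyCalvin/FluxFillDev_fp8_Diffusers has the config for FLUX.1-dev instead of FLUX.1-Fill-dev. The FLUX.1-dev config builds the transformer with 64 input channels, while the Fill checkpoint's x_embedder weight has 384, which is exactly the [3072, 384] vs [3072, 64] mismatch in the error.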
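This kind of config/checkpoint mismatch can be caught before any weights are loaded. The sketch below assumes the config dict uses the FluxTransformer2DModel keys `in_channels`, `num_attention_heads` and `attention_head_dim`, and that `x_embedder` is a linear layer from `in_channels` to `num_attention_heads * attention_head_dim` features, so its weight has shape `(inner_dim, in_channels)`. The helper names are hypothetical, and the config values are assumed to match the two published transformer configs.

```python
from typing import Mapping, Tuple


def expected_x_embedder_shape(config: Mapping[str, int]) -> Tuple[int, int]:
    """Shape of x_embedder.weight implied by a Flux transformer config (hypothetical helper)."""
    inner_dim = config["num_attention_heads"] * config["attention_head_dim"]
    # nn.Linear(in_features, out_features) stores its weight as (out_features, in_features)
    return (inner_dim, config["in_channels"])


def check_config(config: Mapping[str, int], checkpoint_shape: Tuple[int, int]) -> None:
    """Raise early if the config cannot hold the checkpoint's x_embedder weight."""
    expected = expected_x_embedder_shape(config)
    if tuple(checkpoint_shape) != expected:
        raise ValueError(
            f"config expects x_embedder.weight {expected}, checkpoint has {tuple(checkpoint_shape)}; "
            "is a FLUX.1-dev config paired with FLUX.1-Fill-dev weights?"
        )


# Assumed values for the FLUX.1-dev and FLUX.1-Fill-dev transformer configs
flux_dev_config = {"in_channels": 64, "num_attention_heads": 24, "attention_head_dim": 128}
flux_fill_config = {**flux_dev_config, "in_channels": 384}

fill_checkpoint_shape = (3072, 384)  # tensor shape reported in the error log above
check_config(flux_fill_config, fill_checkpoint_shape)  # matching config: no error
try:
    check_config(flux_dev_config, fill_checkpoint_shape)
except ValueError as err:
    print(err)
```

In practice the config dict could come from `FluxTransformer2DModel.load_config(repo_id, subfolder="transformer")` and the checkpoint shape from the safetensors header; both are outside this sketch.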

@Suprhimp
Author

Thanks! It worked for the first case above.

I finally tested the FP8 version against the original version, and found that switching the model to FP8 does not make it run any faster.

I'd also like to test a quantized version. Would its speed also be the same as the original, even after quantization?
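Whether FP8 or another quantized variant is faster is easiest to settle by timing the same pipeline call on each variant. Weight-only schemes such as quanto's qfloat8, or GGUF with a bfloat16 `compute_dtype`, mainly reduce memory, and the A100 has no FP8 tensor cores, so a speedup is not guaranteed. Below is a minimal, standard-library-only timing sketch; `benchmark` is a hypothetical helper, and because CUDA kernels launch asynchronously, a `sync` callback such as `torch.cuda.synchronize` should be passed when timing GPU work.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def benchmark(
    fn: Callable[[], object],
    runs: int = 3,
    warmup: int = 1,
    sync: Optional[Callable[[], None]] = None,
) -> Dict[str, float]:
    """Time fn() end to end; `sync` should block until queued GPU work has finished."""
    if runs < 1:
        raise ValueError("runs must be >= 1")
    # Warm-up calls absorb one-time costs (allocation, kernel selection, offload hooks)
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # without this, asynchronous CUDA launches would make the timing too short
        timings.append(time.perf_counter() - start)
    return {"mean_s": statistics.mean(timings), "min_s": min(timings), "runs": float(runs)}


if __name__ == "__main__":
    # Stand-in CPU workload; with a real pipeline this might be
    # benchmark(lambda: pipe(prompt=..., image=..., mask_image=...), sync=torch.cuda.synchronize)
    result = benchmark(lambda: sum(i * i for i in range(100_000)))
    print(f"mean {result['mean_s']:.4f}s, min {result['min_s']:.4f}s over {int(result['runs'])} runs")
```

Running it once per variant (original bf16, FP8, quanto, GGUF) with identical prompt, size and step count gives a like-for-like comparison.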

@kryali

kryali commented Dec 22, 2024

@Suprhimp you could try the GGUF quantization support that was added very recently:

import numpy as np
import torch
from PIL import Image
from diffusers import FluxFillPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# `image` (H x W x 3 uint8 array), `mask_2d` (H x W array with values in [0, 1])
# and `fill_prompt` come from the surrounding code, which is not shown here.
device = torch.device("cuda")

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/YarvixPA/FLUX.1-Fill-dev-gguf/blob/main/flux1-fill-dev-Q4_0.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
# Model CPU offload moves each component to the GPU only while it runs,
# so the pipeline is not moved to `device` beforehand.
pipe.enable_model_cpu_offload()

mask_image = Image.fromarray((mask_2d * 255).astype(np.uint8))
print(f"Applying Flux Fill with prompt: '{fill_prompt}'")
flux_image = pipe(
    prompt=fill_prompt,
    image=Image.fromarray(image),
    mask_image=mask_image,
    height=image.shape[0],
    width=image.shape[1],
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator(device.type).manual_seed(0),
).images[0]

print("Saving Flux Fill result...")
flux_image.save("output/flux_fill_output.png")
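The snippet passes the input image's own height and width to the pipeline. Flux works on latents downsampled 8x by the VAE and packed into 2x2 patches, so sizes that are multiples of 16 map cleanly, while other sizes may be adjusted by the pipeline (the exact behaviour depends on the diffusers version). A hedged sketch of snapping sizes down to a multiple of 16 before the call; `snap_size` is a hypothetical helper, not a diffusers API.

```python
from typing import Tuple


def snap_size(height: int, width: int, multiple: int = 16) -> Tuple[int, int]:
    """Round height and width down to the nearest multiple, never below one multiple."""

    def snap(x: int) -> int:
        return max(multiple, (x // multiple) * multiple)

    return snap(height), snap(width)


print(snap_size(1632, 1232))  # (1632, 1232): already multiples of 16
print(snap_size(1000, 750))   # (992, 736)
```

With the names from the snippet, `height, width = snap_size(*image.shape[:2])`, followed by resizing both the input image and `mask_image` to `(width, height)` (PIL's `resize` takes width first), would keep image, mask and pipeline arguments aligned.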

3 participants