Tried this on NVIDIA Labs SANA Pipe and made absolutely 0 difference #4

FurkanGozukara opened this issue Jan 9, 2025 · 6 comments

@FurkanGozukara

The model and pipeline are here:

https://github.com/NVlabs/Sana

How I tried it is below. The app works, but there is zero difference in VRAM usage or output.

import torch
import gradio as gr
from patch_conv import convert_model          # https://github.com/mit-han-lab/patch_conv
from app.sana_pipeline import SanaPipeline    # import path may differ in your Sana checkout

pipe = None

def load_model(config_path, model_path):
    global pipe
    if pipe is None:
        if not torch.cuda.is_available():
            return False, "CUDA is not available"
        try:
            pipe = SanaPipeline(config_path)
            pipe.from_pretrained(model_path)
            pipe.register_progress_bar(gr.Progress())
            # wrap the pipeline's Conv2d layers so they run in 4 spatial splits
            pipe = convert_model(pipe, splits=4)
            print("model converted")
            return True, "Model loaded successfully"
        except Exception as e:
            return False, f"Error loading model: {str(e)}"
    return True, "Model already loaded"
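
One way to check whether convert_model actually replaced anything here (which would explain why there was no VRAM difference) is to compare the module classes before and after conversion. A minimal, hypothetical sketch; pipe.model is an assumed attribute for wherever the nn.Module lives inside SanaPipeline:

from collections import Counter

def module_types(model):
    # histogram of module class names; if this is identical before and after
    # convert_model, patch_conv wrapped nothing
    return Counter(type(m).__name__ for m in model.modules())

before = module_types(pipe.model)                 # hypothetical attribute
pipe.model = convert_model(pipe.model, splits=4)
after = module_types(pipe.model)
print({k: (before.get(k, 0), n) for k, n in after.items() if before.get(k, 0) != n})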

@FurkanGozukara (Author)

OK, I tried it here and it broke the output:

    def build_sana_model(self, config):
        model_kwargs = model_init_config(config, latent_size=self.latent_size)
        model = build_model(
            config.model.model,
            use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
            **model_kwargs,
        )
        self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
        self.logger.info(
            f"{model.__class__.__name__}:{config.model.model},"
            f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
        )
        # added: wrap Conv2d layers with patch_conv
        model = convert_model(model, splits=4)
        print("model converted build_sana_model")
        return model
2025-01-10 06:49:45 - [Sana] - INFO - Generating sample from ckpt: models/checkpoints/Sana_1600M_2Kpx_BF16.pth
2025-01-10 06:49:45 - [Sana] - WARNING - Missing keys: ['pos_embed', 'blocks.0.mlp.depth_conv.conv.conv2d.weight', 'blocks.0.mlp.depth_conv.conv.conv2d.bias', 'blocks.1.mlp.depth_conv.conv.conv2d.weight', 'blocks.1.mlp.depth_conv.conv.conv2d.bias', 'blocks.2.mlp.depth_conv.conv.conv2d.weight', 'blocks.2.mlp.depth_conv.conv.conv2d.bias', 'blocks.3.mlp.depth_conv.conv.conv2d.weight', 'blocks.3.mlp.depth_conv.conv.conv2d.bias', 'blocks.4.mlp.depth_conv.conv.conv2d.weight', 'blocks.4.mlp.depth_conv.conv.conv2d.bias', 'blocks.5.mlp.depth_conv.conv.conv2d.weight', 'blocks.5.mlp.depth_conv.conv.conv2d.bias', 'blocks.6.mlp.depth_conv.conv.conv2d.weight', 'blocks.6.mlp.depth_conv.conv.conv2d.bias', 'blocks.7.mlp.depth_conv.conv.conv2d.weight', 'blocks.7.mlp.depth_conv.conv.conv2d.bias', 'blocks.8.mlp.depth_conv.conv.conv2d.weight', 'blocks.8.mlp.depth_conv.conv.conv2d.bias', 'blocks.9.mlp.depth_conv.conv.conv2d.weight', 'blocks.9.mlp.depth_conv.conv.conv2d.bias', 'blocks.10.mlp.depth_conv.conv.conv2d.weight', 'blocks.10.mlp.depth_conv.conv.conv2d.bias', 'blocks.11.mlp.depth_conv.conv.conv2d.weight', 'blocks.11.mlp.depth_conv.conv.conv2d.bias', 'blocks.12.mlp.depth_conv.conv.conv2d.weight', 'blocks.12.mlp.depth_conv.conv.conv2d.bias', 'blocks.13.mlp.depth_conv.conv.conv2d.weight', 'blocks.13.mlp.depth_conv.conv.conv2d.bias', 'blocks.14.mlp.depth_conv.conv.conv2d.weight', 'blocks.14.mlp.depth_conv.conv.conv2d.bias', 'blocks.15.mlp.depth_conv.conv.conv2d.weight', 'blocks.15.mlp.depth_conv.conv.conv2d.bias', 'blocks.16.mlp.depth_conv.conv.conv2d.weight', 'blocks.16.mlp.depth_conv.conv.conv2d.bias', 'blocks.17.mlp.depth_conv.conv.conv2d.weight', 'blocks.17.mlp.depth_conv.conv.conv2d.bias', 'blocks.18.mlp.depth_conv.conv.conv2d.weight', 'blocks.18.mlp.depth_conv.conv.conv2d.bias', 'blocks.19.mlp.depth_conv.conv.conv2d.weight', 'blocks.19.mlp.depth_conv.conv.conv2d.bias']
2025-01-10 06:49:45 - [Sana] - WARNING - Unexpected keys: ['blocks.0.mlp.depth_conv.conv.weight', 'blocks.0.mlp.depth_conv.conv.bias', 'blocks.1.mlp.depth_conv.conv.weight', 'blocks.1.mlp.depth_conv.conv.bias', 'blocks.2.mlp.depth_conv.conv.weight', 'blocks.2.mlp.depth_conv.conv.bias', 'blocks.3.mlp.depth_conv.conv.weight', 'blocks.3.mlp.depth_conv.conv.bias', 'blocks.4.mlp.depth_conv.conv.weight', 'blocks.4.mlp.depth_conv.conv.bias', 'blocks.5.mlp.depth_conv.conv.weight', 'blocks.5.mlp.depth_conv.conv.bias', 'blocks.6.mlp.depth_conv.conv.weight', 'blocks.6.mlp.depth_conv.conv.bias', 'blocks.7.mlp.depth_conv.conv.weight', 'blocks.7.mlp.depth_conv.conv.bias', 'blocks.8.mlp.depth_conv.conv.weight', 'blocks.8.mlp.depth_conv.conv.bias', 'blocks.9.mlp.depth_conv.conv.weight', 'blocks.9.mlp.depth_conv.conv.bias', 'blocks.10.mlp.depth_conv.conv.weight', 'blocks.10.mlp.depth_conv.conv.bias', 'blocks.11.mlp.depth_conv.conv.weight', 'blocks.11.mlp.depth_conv.conv.bias', 'blocks.12.mlp.depth_conv.conv.weight', 'blocks.12.mlp.depth_conv.conv.bias', 'blocks.13.mlp.depth_conv.conv.weight', 'blocks.13.mlp.depth_conv.conv.bias', 'blocks.14.mlp.depth_conv.conv.weight', 'blocks.14.mlp.depth_conv.conv.bias', 'blocks.15.mlp.depth_conv.conv.weight', 'blocks.15.mlp.depth_conv.conv.bias', 'blocks.16.mlp.depth_conv.conv.weight', 'blocks.16.mlp.depth_conv.conv.bias', 'blocks.17.mlp.depth_conv.conv.weight', 'blocks.17.mlp.depth_conv.conv.bias', 'blocks.18.mlp.depth_conv.conv.weight', 'blocks.18.mlp.depth_conv.conv.bias', 'blocks.19.mlp.depth_conv.conv.weight', 'blocks.19.mlp.depth_conv.conv.bias']

[attached image: the generated output is visibly broken]
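
The missing/unexpected keys above are consistent with the conversion running before the checkpoint is loaded: after convert_model, the depth-wise convs expect parameters named ...depth_conv.conv.conv2d.weight, while the checkpoint still stores ...depth_conv.conv.weight, so those layers keep their random initialization and the output breaks. A hedged sketch of an ordering that avoids the mismatch (build_model and model_kwargs stand in for the corresponding steps in Sana's code):

import torch
from patch_conv import convert_model

# 1. build the model exactly as Sana already does (placeholder call)
model = build_model(config.model.model, **model_kwargs)

# 2. load the pretrained weights while the parameter names still match the checkpoint
ckpt = torch.load("models/checkpoints/Sana_1600M_2Kpx_BF16.pth", map_location="cpu")
model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)

# 3. only then wrap the Conv2d layers; the loaded weights are reused inside the wrappers
model = convert_model(model, splits=4)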

@FurkanGozukara (Author)

I tried the VAE and it makes no difference in output or VRAM:

def vae_decode(name, vae, latent):
    if name == "sdxl" or name == "sd3":
        latent = (latent.detach() / vae.config.scaling_factor) + vae.config.shift_factor
        # added: wrap the VAE's Conv2d layers with patch_conv (128 spatial splits)
        vae = convert_model(vae, splits=128)
        print("vae converted in vae_decode")
        samples = vae.decode(latent).sample
    elif "dc-ae" in name:
        ae = vae
        # added: same conversion for the DC-AE decoder
        ae = convert_model(ae, splits=128)
        print("dc-ae converted in vae_decode")
        samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
    else:
        print("error loading vae")
        exit()
    return samples

@lmxyy (Collaborator) commented Jan 10, 2025

DC-AE uses depth-wise convolutions, so it may have some bugs. How large is your resolution for SDXL and SD3? If the resolution is small, the VRAM usage will be similar.

@FurkanGozukara (Author)

> DC-AE uses depth-wise convolutions, so it may have some bugs. How large is your resolution for SDXL and SD3? If the resolution is small, the VRAM usage will be similar.

The resolution is 4096x4096.

I tried 1024 splits and it gave an error.

512 splits still OOMs :)

By the way, the 2048x2048 model fits into 24 GB of VRAM, and using patch_conv still didn't reduce VRAM at all.
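
As a rough sanity check on why 4096x4096 decoding OOMs on 24 GB while 2048x2048 fits, a single full-resolution feature map already eats a big chunk of the budget. The numbers below are back-of-envelope only (128 decoder channels and fp16 are assumptions, not measured from Sana):

def feat_gib(channels, height, width, bytes_per_elem=2):
    # size of one fp16 activation tensor, in GiB
    return channels * height * width * bytes_per_elem / 2**30

print(feat_gib(128, 2048, 2048))  # ~1.0 GiB per feature map at 2K
print(feat_gib(128, 4096, 4096))  # ~4.0 GiB per feature map at 4K (4x per resolution doubling)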

@lmxyy (Collaborator) commented Jan 23, 2025

It seems you've resolved the issue with the VAE tiling in Diffusers, which directly tiles the decoder input. While this approach effectively reduces memory usage, it can come at the cost of some image quality. In contrast, PatchConv may require more GPU memory but guarantees mathematically equivalent results, addressing the PyTorch Conv2D memory issue.

For 2K images, the convolution inputs are not large enough to trigger patchified inference (see this line). However, when increasing the resolution to 4K for SD3.5, PatchConv successfully reduced memory usage to 53GB and the results look correct. Regarding SANA's VAE, I noticed it has already employed the input tiling method in Diffusers.
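
For reference, the Diffusers-side tiling mentioned above is a one-line switch on the VAE. A sketch assuming a diffusers pipeline with an AutoencoderKL VAE (the SD3.5 checkpoint id is only an example):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()   # decode the latent in overlapping tiles instead of one pass
pipe.vae.enable_slicing()  # optional: also split the batch dimension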

@Shyryp commented Jan 24, 2025

> It seems you've resolved the issue with the VAE tiling in Diffusers, which directly tiles the decoder input. While this approach effectively reduces memory usage, it can come at the cost of some image quality. In contrast, PatchConv may require more GPU memory but guarantees mathematically equivalent results, addressing the PyTorch Conv2D memory issue.
>
> For 2K images, the convolution inputs are not large enough to trigger patchified inference (see this line). However, when increasing the resolution to 4K for SD3.5, PatchConv successfully reduced memory usage to 53GB and the results look correct. Regarding SANA's VAE, I noticed it has already employed the input tiling method in Diffusers.

I haven't studied the code around this limitation, or the technique in general, in detail, but my questions are:
For 2K images (and lower), is it possible to apply patching? Or is it actually impossible, technically and mathematically?

Is it theoretically possible to reduce VRAM requirements at small parameter values (small generated-image sizes)? Or does it only help in combination with large ones?

In my opinion, this is in some sense similar to traditional "tile upscaling", where the image is split into smaller tiles and then regenerated in parts (thereby avoiding running out of VRAM). Are there any problems with patch boundaries in your implementation, or with patches in general? Will patches capture the context of the whole image during generation? Or is there no such problem due to how patches are created and used in your implementation?
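
On the patch-boundary question, here is a minimal illustrative sketch (not the project's actual code) of why splitting a convolution's input with an overlapping halo is mathematically equivalent to running it in one pass, so there are no seams, unlike tile-based regeneration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 8, 64, 64)
w = torch.randn(16, 8, 3, 3)
b = torch.randn(16)

# reference: one full 3x3 convolution with padding=1
ref = F.conv2d(x, w, b, padding=1)

# patched: pad once, then split along height into two overlapping slices so every
# output row still sees its complete 3x3 window
xp = F.pad(x, (1, 1, 1, 1))                 # (1, 8, 66, 66)
top = F.conv2d(xp[:, :, :34, :], w, b)      # output rows 0..31
bot = F.conv2d(xp[:, :, 32:, :], w, b)      # output rows 32..63
patched = torch.cat([top, bot], dim=2)

print(torch.allclose(ref, patched, atol=1e-5))  # True: identical up to fp error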
