Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable diffusion mlx #474

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

pranav4501
Copy link

@pranav4501 pranav4501 commented Nov 20, 2024

Sharded stable diffusion inference for mlx
#159

Changes:

  • Sharded Stable Diffusion 2-1 Base mlx
  • Handled diffusion steps by looping the whole model
  • Added back inference state
  • Modified grpc and proto to support inference
  • New endpoint for image generation
  • Streaming progress for image generation
  • Handling multiple submodels in a single model

Sharding process:

  1. Stable Diffusion contains three models : CLIP( text encoder) , UNET( Denoising Transformer) and VAE (Image encoder and decoder)
  2. Stable diffusion hugging face repo contains a model_index.json and a folder for each model with its config. I combined all the models configs and loaded it to the model.
  3. The shard is then divided into 3 shards of each model (clip, unet, vae). This works something like the whole model is 37 layers of which 22, 8, 7 are the number of layers for each model in that order. So, a shard of (0,27) is made of shard(0,22, 'clip'), shard(0,5,'unet'), shard(0,0,'vae')
  4. Each model is manually sharded into individual layers.
  5. Then, the inference pipeline is clip.encode(text) -> unet.denoise_latent() for 50 steps -> vae.decode_image().
  6. This is implemented as clip.encode(text) if step ==1 -> unet.denoise_latent() for a step -k -> vae.decode_image() if step ==50 . This pipeline is implemented for 50 steps while maintaining the intermediate results and step count in the inference_state

@AlexCheema
Copy link
Contributor

Just so I understand what's going on, how is the model split up between devices? Lets say I have 3 devices with different capabilities, how does that work?

@pranav4501
Copy link
Author

There are no changes to that part. It's how the partition algorithm splits the shards across the devices.

@AlexCheema
Copy link
Contributor

There are no changes to that part. It's how the partition algorithm splits the shards across the devices.

I see. The difference here is the layers are non-uniform. That means they won't necessarily get split proportional to the memory used right?

@pranav4501
Copy link
Author

Yeah, layers are non-uniform, so the split memory isn't exactly proportional to the number of layers. Can we split using the number of params?

@AlexCheema
Copy link
Contributor

Yeah, layers are non-uniform, so the split memory isn't exactly proportional to the number of layers. Can we split using the number of params?

This is probably fine as long as the layers aren't wildly different in size. Do you know roughly how different in size they are?

@pranav4501
Copy link
Author

Unet does have couple larger layers because of upsampled dims and clip text encoder has comparatively smaller layers as it can be easily split similar to llms, made of transformer blocks. We can combine 2 clip layers and split UNET further to make it more uniform.
CLIP ( 1.36GB -> 22 layers : uniformly split), UNET ( 3.46GB -> 10 layers: non-uniform), VAE (346 MB -> 10 layers : non-uniform )

@blindcrone
Copy link
Contributor

I think at some point it would make sense to allow more granular sharding of models than just transformer blocks anyway, and this could involve updating to a memory-footprint heuristic based on dtypes and parameters rather than assuming uniform layer blocks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants