Stable diffusion mlx #474
base: main
Conversation
Just so I understand what's going on, how is the model split up between devices? Let's say I have 3 devices with different capabilities; how does that work?
There are no changes to that part. It's how the partition algorithm splits the shards across the devices.
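For illustration, a minimal sketch of that kind of partitioning: contiguous layer ranges assigned in proportion to each device's memory capacity. The function name and the memory numbers are hypothetical, not the actual exo partitioning code.

```python
# Hypothetical sketch: split n_layers contiguous layers across devices in
# proportion to each device's memory capacity. Assumes uniform layer sizes.

def partition_layers(n_layers: int, device_memory: list[int]) -> list[tuple[int, int]]:
    """Return (start, end) layer ranges per device, end exclusive."""
    total = sum(device_memory)
    shards = []
    start = 0
    cumulative = 0
    for i, mem in enumerate(device_memory):
        cumulative += mem
        # Round the cumulative memory fraction to a layer boundary;
        # the last device always takes the remainder.
        end = n_layers if i == len(device_memory) - 1 else round(n_layers * cumulative / total)
        shards.append((start, end))
        start = end
    return shards

# e.g. 32 layers over devices with 16, 8, and 8 GB:
print(partition_layers(32, [16, 8, 8]))  # [(0, 16), (16, 24), (24, 32)]
```

Note the rounding to layer boundaries is exactly where non-uniform layers make the memory split inexact, which is what the rest of the thread discusses.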
I see. The difference here is that the layers are non-uniform. That means they won't necessarily get split in proportion to the memory used, right?
Yeah, the layers are non-uniform, so memory isn't split exactly in proportion to the number of layers. Can we split using the number of params?
This is probably fine as long as the layers aren't wildly different in size. Do you know roughly how different in size they are? |
The UNet does have a couple of larger layers because of the upsampled dims, while the CLIP text encoder has comparatively smaller layers; since it's made of transformer blocks, it can be split easily, similar to LLMs. We can combine two CLIP layers and split the UNet further to make the shards more uniform.
I think at some point it would make sense to allow more granular sharding of models than just transformer blocks anyway. This could involve moving to a memory-footprint heuristic based on dtypes and parameter counts rather than assuming uniform layer blocks.
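A rough sketch of what such a memory-footprint heuristic could look like: weight each layer by its parameter count times its dtype size, then cut contiguous ranges so each device's share of total bytes tracks its share of total memory. All names and sizes below are illustrative assumptions, not exo's API.

```python
# Hypothetical memory-footprint heuristic: partition by per-layer bytes
# (params * dtype size) instead of assuming uniform layer blocks.

def layer_bytes(param_counts: list[int], dtype_size: int = 2) -> list[int]:
    """Bytes per layer, e.g. dtype_size=2 for fp16."""
    return [n * dtype_size for n in param_counts]

def partition_by_bytes(layer_sizes: list[int], device_memory: list[int]) -> list[tuple[int, int]]:
    """Assign contiguous layer ranges so each device's byte share roughly
    matches its memory share. Greedy: take layers until the cumulative
    byte target for this device is reached."""
    total_bytes = sum(layer_sizes)
    total_mem = sum(device_memory)
    shards, start, acc, target = [], 0, 0.0, 0.0
    for mem in device_memory[:-1]:
        target += total_bytes * mem / total_mem
        end = start
        while end < len(layer_sizes) and acc + layer_sizes[end] <= target:
            acc += layer_sizes[end]
            end += 1
        # Ensure progress even if a single layer overshoots the target.
        if end == start and end < len(layer_sizes):
            acc += layer_sizes[end]
            end += 1
        shards.append((start, end))
        start = end
    shards.append((start, len(layer_sizes)))
    return shards

# Non-uniform layers (UNet-like: big layers at the ends), two equal devices:
print(partition_by_bytes([4, 1, 1, 4, 1, 1], [1, 1]))  # [(0, 3), (3, 6)]
```

The same byte-count weighting would also let a UNet layer and a pair of combined CLIP layers be compared on an equal footing.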
Sharded stable diffusion inference for mlx
#159
Changes:
Sharding process: