Only the first shard is prefilled #1
Funnily enough, this "bug" is also kind of a feature: it speeds up the prefill stage significantly (which is typically compute-bound), and you still get good results out of the model, so it's hard to notice. It tends to fail once you have a very long context and you see context collapse (see e.g. exo-explore/exo#23).
I'm not sure I understand the issue correctly. Currently we are doing pipeline parallelism, and having only the first shard process the embedding looks fine unless I'm missing something: prompt -> first shard -> second shard -> logits -> token-by-token generation, shard by shard.
You're right, I was overcomplicating things -- neither exo nor mlx_sharding has this issue :)
First of all, awesome repo, really love what you did here!
This is a bug that exo also has (see exo-explore/exo#12).
The issue is that the prompt is only loaded (prefilled) into the layers of the first shard, see https://github.com/mzbac/mlx_sharding/blob/main/server/model/deepseek_v2.py#L441 - this is the wrong condition for taking a prompt. The prompt should be prefilled into all layers, and only then should generation begin. An interesting corollary is that prefill can happen in parallel across all shards, since there's no dependency - just the prompt.
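To make the failure mode concrete, here is a minimal sketch of the difference between prefilling only the first shard and prefilling all of them. The `Shard`, `prefill`, and `buggy_prefill` names are hypothetical stand-ins, not the mlx_sharding API, and the per-layer "KV cache" is reduced to a plain list of token ids:

```python
class Shard:
    """A pipeline stage owning a contiguous slice of layers with per-layer KV caches."""
    def __init__(self, name, num_layers):
        self.name = name
        self.caches = [[] for _ in range(num_layers)]  # one KV cache per layer

    def forward(self, tokens):
        # Stand-in for attention: record the processed tokens in every
        # layer's cache, then pass the result on to the next stage.
        for cache in self.caches:
            cache.extend(tokens)
        return tokens  # simplified "hidden states" for the next shard

def prefill(shards, prompt):
    # Correct behavior: the prompt flows through *every* shard, so all
    # layers' KV caches are populated before generation starts.
    hidden = prompt
    for shard in shards:
        hidden = shard.forward(hidden)

def buggy_prefill(shards, prompt):
    # The reported bug: only the first shard sees the prompt, so later
    # shards begin generation with empty caches.
    shards[0].forward(prompt)

shards = [Shard("first", 2), Shard("second", 2)]
buggy_prefill(shards, [1, 2, 3])
print(len(shards[1].caches[0]))  # 0 -> the second shard never saw the prompt

shards = [Shard("first", 2), Shard("second", 2)]
prefill(shards, [1, 2, 3])
print(len(shards[1].caches[0]))  # 3 -> every layer holds the prompt's entries
```

With the buggy condition the model still generates tokens, since later shards simply attend over whatever (empty) context they have, which is why the bug is easy to miss on short prompts.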
You can most easily reproduce or notice this bug with very long contexts, or by making the first shard contain only one layer.
exo is offering a $100 bounty if you want to also fix it there :)