[Storage Cleaner] Speed up unsharding of some legacy checkpoints #488
Conversation
Very cool! Just a couple questions
olmo/checkpoint.py (outdated)

if rank_size == 0:
    return

temp: np.ndarray = torch.zeros(rank_size, dtype=shard0_md.tensor_properties.dtype).numpy()
What's the purpose of this temp array? If it's just to get the number of bytes, you should be able to infer that from the data type and size.
Just the type and number of bytes. I've changed the code to assume fp32 (c58f4b4). I already know they are not bf16, at least, because .numpy() fails for bf16.
Yeah, when we train in bf16 (or fp16), the main copy of the model weights is always fp32.
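For illustration, a minimal sketch of the approach discussed above: derive the byte count directly from the assumed fp32 dtype instead of materializing a temporary array. The helper name and rank_size parameter are placeholders mirroring the snippet above, not the exact code from the PR.

import numpy as np

def rank_num_bytes(rank_size: int) -> int:
    # Assume fp32: the main copy of the weights is fp32, and bf16 tensors
    # would fail torch's .numpy() anyway.
    return rank_size * np.dtype(np.float32).itemsize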
olmo/checkpoint.py (outdated)

temp: np.ndarray = torch.zeros(1, dtype=shard0_md.tensor_properties.dtype).numpy()
numpy_type = temp.dtype
Similar question here, but looks like it's just to get the data type? It's probably reasonable to assume FP32.
Changed to assume fp32 (c58f4b4).
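Similarly, a sketch of what the fixed-dtype version of this snippet could look like, assuming fp32; this is not necessarily the exact change in c58f4b4.

import numpy as np

# State the numpy dtype directly instead of round-tripping a one-element
# tensor through torch just to read its dtype.
numpy_type = np.dtype(np.float32)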
LGTM
This PR changes the unsharding mechanism for legacy checkpoints to use processes and shared memory instead of threads. In one case where the world size was 1024, this implementation brought the unsharding time down from 6 hours to 30 minutes. This implementation is slower than the old one at smaller scales, but that is acceptable.
One option was to keep the old mechanism around in the code as well, but since we are trying to get rid of legacy sharded checkpoints, it doesn't seem worth keeping that code around.
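For context, a rough sketch of the processes-plus-shared-memory pattern described above. This is not the actual OLMo implementation: load_rank_shard, unshard, and the sizes used are placeholders. Each worker process attaches to one shared fp32 buffer and writes its rank's slice into place.

import numpy as np
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory


def load_rank_shard(rank: int) -> tuple[int, np.ndarray]:
    # Placeholder: the real storage cleaner would read this rank's legacy
    # sharded checkpoint here; this stub just fabricates a small fp32 slice.
    shard = np.full(4, float(rank), dtype=np.float32)
    return rank * 4, shard


def _unshard_rank(shm_name: str, total_numel: int, rank: int) -> None:
    # Worker process: attach to the shared buffer and copy this rank's slice
    # of the flat fp32 parameters into place.
    shm = SharedMemory(name=shm_name)
    flat = np.ndarray((total_numel,), dtype=np.float32, buffer=shm.buf)
    offset, shard = load_rank_shard(rank)
    flat[offset : offset + shard.size] = shard
    del flat  # drop the buffer export so close() succeeds
    shm.close()


def unshard(world_size: int, total_numel: int) -> np.ndarray:
    shm = SharedMemory(create=True, size=total_numel * np.dtype(np.float32).itemsize)
    try:
        procs = [
            Process(target=_unshard_rank, args=(shm.name, total_numel, rank))
            for rank in range(world_size)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        view = np.ndarray((total_numel,), dtype=np.float32, buffer=shm.buf)
        result = view.copy()  # copy out before releasing the shared block
        del view
    finally:
        shm.close()
        shm.unlink()
    return result


if __name__ == "__main__":
    print(unshard(world_size=4, total_numel=16))

The point of the pattern is that the CPU-bound copy work runs in separate processes rather than GIL-bound threads, while the shared block avoids giving every worker its own copy of the full parameter buffer.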