
[Storage Cleaner] Speed up unsharding of some legacy checkpoints #488

Merged
2015aroras merged 6 commits into main from shanea/optimize-unsharding-2 on Mar 7, 2024

Conversation

@2015aroras (Collaborator)

This PR changes the unsharding mechanism for legacy checkpoints to use processes and shared memory instead of threads. In one case, with a world size of 1024, this implementation brought the unsharding time down from 6 hours to 30 minutes. It is slower than the old one at smaller scales, but that trade-off is acceptable.

One option was to also keep the old mechanism in the code, but since we are trying to get rid of legacy sharded checkpoints, it doesn't seem worth keeping that code around.

@epwalsh (Member) left a comment

Very cool! Just a couple of questions.

if rank_size == 0:
return

temp: np.ndarray = torch.zeros(rank_size, dtype=shard0_md.tensor_properties.dtype).numpy()
Member:

What's the purpose of this temp array? If it's just to get the number of bytes, you should be able to infer that from the data type and size.

@2015aroras (Collaborator, Author) commented Mar 7, 2024:

Just the type and number of bytes. I've changed the code to assume fp32 (c58f4b4). I already know they are not bf16, at least, because .numpy() fails for bf16.
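For reference, under the fp32 assumption the byte count follows directly from the element count, with no temporary array. A minimal sketch (`rank_size` stands in for the variable in the excerpt above):

```python
import numpy as np

rank_size = 1024  # stand-in for the shard's element count
nbytes = rank_size * np.dtype(np.float32).itemsize  # fp32 = 4 bytes/element
```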

Member:

Yeah, when we train in bf16 (or fp16), the main copy of the model weights is always fp32.

Comment on lines 1035 to 1036
temp: np.ndarray = torch.zeros(1, dtype=shard0_md.tensor_properties.dtype).numpy()
numpy_type = temp.dtype
Member:

Similar question here, but looks like it's just to get the data type? It's probably reasonable to assume FP32.

Collaborator (Author):

Changed to assume fp32 (c58f4b4).
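Under that fp32 assumption, the torch round-trip in the excerpt above is unnecessary; the numpy dtype can be named directly (a sketch — `shard0_md` is part of the PR's checkpoint metadata and is not needed once the dtype is fixed):

```python
import numpy as np

# Replaces the temp-array trick: assume fp32 instead of deriving the
# dtype from checkpoint metadata via a throwaway torch tensor.
numpy_type = np.dtype(np.float32)
```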

@epwalsh (Member) left a comment

LGTM

@2015aroras 2015aroras merged commit 752353b into main Mar 7, 2024
11 checks passed
@2015aroras 2015aroras deleted the shanea/optimize-unsharding-2 branch March 7, 2024 22:51