Conversation

@mryab mryab commented Mar 11, 2023

This PR changes the block initialization logic to use PyTorch meta tensors before assigning values from the state_dict. This avoids unnecessary memory allocation and parameter initialization steps, which can take a long time (and a lot of RAM) for large models.
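The technique can be sketched in plain PyTorch. This is an illustrative sketch, not the PR's actual code: `nn.Linear` stands in for the real transformer block, and the state_dict is fabricated for the example.

```python
import torch
import torch.nn as nn

# Build the module on the meta device: parameters carry only shape/dtype
# metadata with no backing storage, so no RAM is allocated and no random
# initialization work is done. (nn.Linear stands in for the real block.)
with torch.device("meta"):
    block = nn.Linear(4, 4)
assert block.weight.is_meta

# Materialize uninitialized storage on the target device, then overwrite
# it with the real weights from a state_dict (fabricated for this sketch).
block = block.to_empty(device="cpu")
block.load_state_dict({"weight": torch.eye(4), "bias": torch.zeros(4)})
```

Because the expensive init kernels never run on real memory, construction cost is essentially just the final weight copy.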

Before this PR: 1 block loaded in ~16 seconds

python -m petals.cli.run_server bigscience/bloom-petals --new_swarm --throughput 100 --num_blocks 1
Mar 11 20:28:06.177 [INFO] Running Petals 1.1.3
Mar 11 20:28:07.274 [INFO] This server is accessible directly
Mar 11 20:28:07.714 [INFO] Connecting to a private swarm, initial peers: []
Mar 11 20:28:07.715 [INFO] Running a server on ['/ip4/172.27.77.70/tcp/41059/p2p/12D3KooWA4oco1QGgLWRuJdVoyhCv1559B9x1tAjvomuKN8ZwYrY', '/ip4/127.0.0.1/tcp/41059/p2p/12D3KooWA4oco1QGgLWRuJdVoyhCv1559B9x1tAjvomuKN8ZwYrY', '/ip6/2a02:6b8:0:3201:d9cc:83ae:8057:2b4e/tcp/40189/p2p/12D3KooWA4oco1QGgLWRuJdVoyhCv1559B9x1tAjvomuKN8ZwYrY', '/ip6/::1/tcp/40189/p2p/12D3KooWA4oco1QGgLWRuJdVoyhCv1559B9x1tAjvomuKN8ZwYrY']
Mar 11 20:28:07.752 [INFO] Model weights will be loaded in 8-bit format
Mar 11 20:28:07.753 [INFO] Attention cache for all blocks will consume up to 0.50 GiB
Mar 11 20:28:07.843 [INFO] Reachability service started
Mar 11 20:28:11.264 [INFO] Announced that blocks [0] are joining
Mar 11 20:28:27.418 [INFO] Loaded bigscience/bloom-petals block 0, <All keys matched successfully>

After: 1 block loaded in ~2 seconds

python -m petals.cli.run_server bigscience/bloom-petals --new_swarm --throughput 100 --num_blocks 1
Mar 11 20:27:32.121 [INFO] Running Petals 1.1.3
Mar 11 20:27:33.257 [INFO] This server is accessible directly
Mar 11 20:27:33.700 [INFO] Connecting to a private swarm, initial peers: []
Mar 11 20:27:33.701 [INFO] Running a server on ['/ip4/172.27.77.70/tcp/40367/p2p/12D3KooWLgrNbQbxYaG9W5CJUwNz68fpe8qjm3x4CWDLbzFg2TzY', '/ip4/127.0.0.1/tcp/40367/p2p/12D3KooWLgrNbQbxYaG9W5CJUwNz68fpe8qjm3x4CWDLbzFg2TzY', '/ip6/2a02:6b8:0:3201:d9cc:83ae:8057:2b4e/tcp/37407/p2p/12D3KooWLgrNbQbxYaG9W5CJUwNz68fpe8qjm3x4CWDLbzFg2TzY', '/ip6/::1/tcp/37407/p2p/12D3KooWLgrNbQbxYaG9W5CJUwNz68fpe8qjm3x4CWDLbzFg2TzY']
Mar 11 20:27:33.730 [INFO] Model weights will be loaded in 8-bit format
Mar 11 20:27:33.731 [INFO] Attention cache for all blocks will consume up to 0.50 GiB
Mar 11 20:27:33.861 [INFO] Reachability service started
Mar 11 20:27:34.968 [INFO] Announced that blocks [0] are joining
Mar 11 20:27:36.625 [INFO] Loaded bigscience/bloom-petals block 0, <All keys matched successfully>

@mryab mryab requested review from borzunov and justheuristic and removed request for justheuristic March 11, 2023 18:25
@borzunov (Collaborator) commented:

Awesome results!

The total block loading time (including the time to move it to the GPU) went down from 57 sec to 25 sec on my machines. This means that, given that the blocks are already downloaded, servers will spend roughly half as much time restarting after being preempted or during rebalancing.

@borzunov borzunov changed the title Init WrappedBloomBlock with meta weights Speed up loading blocks by running init on meta device Mar 12, 2023
@borzunov borzunov changed the title Speed up loading blocks by running init on meta device Speed up loading blocks using init with meta weights Mar 12, 2023
@mryab mryab merged commit 793726b into main Mar 12, 2023
@mryab mryab deleted the init_empty_weights branch March 12, 2023 21:49
borzunov added a commit that referenced this pull request Apr 25, 2023
- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of the dtype of the weights (huggingface/accelerate#920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
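The `"auto"` behavior described above can be illustrated with a hypothetical helper (`resolve_dtype` is not the actual Petals function; it only models the stated rule):

```python
import torch

# Hypothetical helper (not the actual Petals API) illustrating the
# torch_dtype resolution described above: "auto" falls back to the
# dtype recorded in the model's config.json, while an explicit dtype
# always wins.
def resolve_dtype(torch_dtype, config_dtype):
    return config_dtype if torch_dtype == "auto" else torch_dtype

# "auto" reads the dtype from the config; an explicit dtype overrides it.
assert resolve_dtype("auto", torch.bfloat16) is torch.bfloat16
assert resolve_dtype(torch.float32, torch.bfloat16) is torch.float32
```

Making the block dtype explicit this way keeps weights and attention caches in the intended half-precision format instead of silently widening to float32.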