[RFC]: Add runtime weight update API #5723

Open
lyuqin-scale opened this issue Jun 20, 2024 · 4 comments
Motivation.

In online RL training, vLLM can significantly accelerate the rollout stage. To achieve this, weights must be synced from the main training process to the vLLM worker processes, after which the existing vLLM API can load them:
model_runner.model.load_weights
An example of such an implementation can be found in OpenRLHF: https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_worker_wrap.py

However, users currently have to monkey-patch the vLLM worker to introduce this behavior. It would be great if vLLM natively supported weight sync at runtime.

Proposed Change.

  1. Add an NCCL-based weight-sync process group during vLLM initialization, so that the main process can later dist.broadcast weights to the vLLM worker processes.
  2. Expose a weight-sync API, for example:
    def update_weight(self, name, dtype, shape)

Then, in the master process, the user can sync weights as follows (modified from OpenRLHF):

params = list(model.named_parameters())
num_params = len(params)
for count, (name, param) in enumerate(params, start=1):
    # Fire all vLLM engines for broadcast
    if torch.distributed.get_rank() == 0:
        # Under DeepSpeed ZeRO-3, the full (unpartitioned) shape is in param.ds_shape
        shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
        refs = [
            engine.update_weight.remote(
                name, dtype=param.dtype, shape=shape,
                empty_cache=count == num_params,  # free the cache after the last weight
            )
            for engine in self.vllm_engines
        ]

        torch.distributed.broadcast(param.data, 0, group=self._model_update_group)
        ray.get(refs)
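Worker-side, the corresponding update_weight endpoint could look roughly like the OpenRLHF monkey patch; this is a hedged sketch, not vLLM's actual API — _model_update_group and model_runner follow OpenRLHF/vLLM naming, while the device parameter is added here purely for illustration:

```python
import torch
import torch.distributed as dist

# Hedged sketch of the worker-side endpoint, modeled on OpenRLHF's
# vllm_worker_wrap.py. self._model_update_group is the extra weight-sync
# process group; self.model_runner is vLLM's model runner.
class WorkerWeightSyncMixin:
    def update_weight(self, name, dtype, shape, device="cuda", empty_cache=False):
        # Allocate a temp tensor, receive the broadcast from the trainer
        # (rank 0 of the weight-sync group), then reuse vLLM's existing
        # load_weights entry point.
        weight = torch.empty(shape, dtype=dtype, device=device)
        dist.broadcast(weight, src=0, group=self._model_update_group)
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight
        if empty_cache and torch.cuda.is_available():
            torch.cuda.empty_cache()
```

The temp tensor is freed immediately after load_weights so at most one extra weight-sized buffer is live per worker at any time.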

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

@youkaichao

thanks for the information! can you describe the processes involved, which tensor lives on which device and in which process, and what the desired transfer is?

also, cc @hijkzzz from #5477

@lyuqin-scale (Author)

@youkaichao
processes: vLLM worker process(es) on GPUs 0 and 1, main training process on GPU 2
tensors: HF weights live in the main training process on GPU 2; they are dist.broadcast into temp tensors of the same size on the vLLM workers on GPUs 0 and 1, and then, within each vLLM worker process:
model_runner.model.load_weights(weights=[(name, weight)])
where weight is the temp tensor holding one of the weights broadcast from the main process to the vLLM worker process

the transfer is preferably done via NCCL

@hijkzzz

hijkzzz commented Jun 21, 2024

More importantly, we need to support establishing an NCCL group between DeepSpeed and vLLM engines.
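One way such a group could be established is a shared TCP rendezvous that the DeepSpeed trainer joins as rank 0 and each vLLM engine joins at its own rank via an engine-side hook. This is a hedged sketch under those assumptions — join_weight_sync_group is a hypothetical helper, not an existing vLLM or DeepSpeed API:

```python
import torch.distributed as dist

# Hypothetical helper: trainer (rank 0) and the vLLM engines (ranks 1..N)
# all call this with the same rendezvous address and world_size = 1 + N,
# yielding one process group that spans both frameworks.
def join_weight_sync_group(rank, world_size, master_address, master_port,
                           backend="nccl"):
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_address}:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    return dist.group.WORLD
```

Note that if a process has already initialized a default torch.distributed group (as vLLM workers do for tensor parallelism), the extra group has to be created without re-initializing the default one — OpenRLHF works around this with its own init helper.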

@hijkzzz

hijkzzz commented Jun 27, 2024

Update:
This API should support LoRA weight updates as much as possible.

See:
https://github.com/OpenLLMAI/OpenRLHF/pull/335/files
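One way to keep a dense update_weight path working for LoRA is to merge each adapter into its base weight before broadcasting. A hypothetical helper (not OpenRLHF's actual code; scaling is LoRA's alpha / r factor):

```python
import torch

# Hypothetical helper: merge a LoRA adapter into its base weight so the
# merged dense tensor can be broadcast through an unchanged update_weight
# path. Standard LoRA merge: W' = W + scaling * (B @ A).
def merge_lora_weight(base: torch.Tensor,
                      lora_a: torch.Tensor,
                      lora_b: torch.Tensor,
                      scaling: float) -> torch.Tensor:
    return base + scaling * (lora_b @ lora_a)
```

Merging trades extra trainer-side compute for zero changes on the vLLM side; pushing the A/B factors separately would need LoRA-aware handling in the worker.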
