Commit 6899e95

Authored by casteryh, allenwang28, and JenniferWang

shared memory multiprocess prefetch for weight update (#430)

Co-authored-by: Allen Wang <[email protected]>
Co-authored-by: Jiyue Wang <[email protected]>
1 parent 7b93ece · commit 6899e95

File tree

4 files changed: +1476 −5 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ dev = [
     "tomli>=1.1.0",
     "anyio",
     "pytest-asyncio",
+    "multiprocess",
 ]
 oss = [
     "torch",

src/forge/actors/generator.py

Lines changed: 130 additions & 5 deletions
@@ -13,10 +13,12 @@
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import dataclass, field
+from typing import Optional

 import torch
 import torchstore as ts
-from monarch.actor import current_rank, endpoint, ProcMesh
+from monarch.actor import current_rank, endpoint, ProcMesh, this_host
+
 from vllm.config import VllmConfig

 from vllm.engine.arg_utils import EngineArgs
@@ -60,6 +62,7 @@
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
+from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -92,6 +95,8 @@ class Generator(ForgeActor):
     engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
     sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
     use_dcp_for_weight_sync: bool | None = None
+    prefetch_weights_to_shm: bool = True
+    n_fetcher_procs: int = 8

     def __post_init__(self):
         super().__init__()
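Two new dataclass fields control the feature: whether to prefetch weights into shared memory, and how many fetcher processes to spawn. A minimal sketch of supplying them; the field names come from the diff above, while the surrounding structure and the model id are assumptions about how a Generator might be configured:

# Hypothetical config fragment, not taken from this commit.
generator_cfg = {
    "engine_args": {"model": "my-org/my-model"},  # placeholder model id
    "prefetch_weights_to_shm": True,              # prefetch weights into host shared memory
    "n_fetcher_procs": 8,                         # fetcher processes spawned on the generator's host
}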
@@ -226,11 +231,61 @@ async def setup(self):
             log_stats=None,
         )
         self._start_processing()
+        if self.prefetch_weights_to_shm:
+            self._spawn_fetchers()
+
+    def _spawn_fetchers(self):
+        """Spawn weight fetchers that prefetch weights from torchstore to shared memory."""
+        # TODO: this assumes the generator is on the same host as the worker
+        # and only works for single-host generators. Figure out how to support
+        # generators with workers spanning multiple hosts.
+        fetcher_procs = this_host().spawn_procs(
+            per_host={"procs": self.n_fetcher_procs}
+        )
+        self._fetcher_procs = fetcher_procs
+        self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)

     def _start_processing(self):
         if self._run_task is None or self._run_task.done():
             self._run_task = asyncio.create_task(self.run())

+    async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
+        for handle in state_dict.values():
+            handle.drop()
+
+    async def _fetch_weights(
+        self,
+        version: int,
+    ) -> dict[str, SharedTensorHandle]:
+        """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
+        t = Tracer("generator_perf/_fetch_weights")
+        t.start()
+        prefix = get_param_prefix(version)
+        matching_keys = await ts.keys(prefix)
+        hf_param_names = [extract_param_name(key) for key in matching_keys]
+
+        n_fetchers = self.weight_fetchers.size()
+
+        def split_keys(keys):
+            return [keys[i::n_fetchers] for i in range(n_fetchers)]
+
+        futures = []
+        for i, names in enumerate(split_keys(hf_param_names)):
+            fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
+                version=version, param_names=names
+            )
+            futures.append(fut)
+
+        sub_state_dicts = [await fut for fut in futures]
+
+        state_dict = {}
+        for sd in sub_state_dicts:
+            state_dict.update(sd)
+
+        t.stop()
+
+        return state_dict
+
     @endpoint
     async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         """Generate a response for the given prompt
@@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None:
        >>> await trainer.push_weights()
        >>> generator.update_weights(version)
        """
+        # TODO: enable shared memory prefetch for DCP-based weight sync
+        if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
+            logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
+            fetch_fut = asyncio.create_task(self._fetch_weights(version))
+        else:
+            fetch_fut = None
         # Serialize updates (only one update at a time)
         async with self.update_lock:
             # Grab the lock to stop accepting requests and wait on pending requests
@@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None:
             )

             logger.debug(f"Starting weight update on {self.__class__.__name__}")
-            # Call update_weights on every generator worker
-            await self.worker.update_weights.call(version=version)
+
+            if fetch_fut is not None:
+                t = Tracer("generator_perf/waiting_for_fetch_weights")
+                t.start()
+                fetched_weights = await fetch_fut
+                t.stop()
+                # Call update_weights on every policy_worker
+                await self.worker.update_weights.call(
+                    shared_memory_state_dict=fetched_weights
+                )
+                await self._drop_shared_memory(fetched_weights)
+            else:
+                await self.worker.update_weights.call(version=version)
             self.generator_version = version

             # After updating the weights, we need to reset the KV cache
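The prefetch task is created before the update lock is taken, so the torchstore reads overlap with draining in-flight requests; the task is only awaited once the workers are ready to load. A minimal, self-contained sketch of that overlap pattern; the function names and sleeps below are stand-ins, not the Generator's actual members:

import asyncio


async def fetch_weights(version: int) -> dict:
    await asyncio.sleep(1.0)  # stand-in for pulling tensors out of torchstore into shared memory
    return {"layer.0.weight": "<handle>"}


async def drain_pending_requests() -> None:
    await asyncio.sleep(1.0)  # stand-in for pausing the scheduler and waiting on pending requests


async def update(version: int, update_lock: asyncio.Lock) -> None:
    # Start the slow fetch immediately so it overlaps with the drain below.
    fetch_task = asyncio.create_task(fetch_weights(version))
    async with update_lock:
        await drain_pending_requests()
        weights = await fetch_task  # by now most of the fetch time has already elapsed
        print(f"loading v{version}: {list(weights)}")


async def main() -> None:
    await update(version=1, update_lock=asyncio.Lock())


asyncio.run(main())  # wall time is roughly 1s instead of the 2s a sequential fetch would take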
@@ -490,6 +562,7 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
         await actor.stop.call()
         await stop_proc_mesh(actor._worker_procs)
         await stop_proc_mesh(actor._generator_proc)
+        await stop_proc_mesh(actor._fetcher_procs)

     @endpoint
     async def _test_save_model_params(self):
@@ -573,14 +646,42 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
         return self.worker.execute_model(schedule)

     @endpoint
-    async def update_weights(self, version: int) -> None:
+    async def update_weights(
+        self,
+        version: Optional[int] = None,
+        *,
+        shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
+    ) -> None:
         model = self.worker.model_runner.model
+        if shared_memory_state_dict is not None:
+            logger.info("[PolicyWorker] update weights from shared memory.")
+            t = Tracer(
+                "generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
+            )
+            t.start()
+            loaded_weights = set()
+            for name, param_handle in shared_memory_state_dict.items():
+                # Use context manager for automatic cleanup
+                with param_handle.to_shared_tensor() as shared_tensor:
+                    param = shared_tensor.tensor
+                    loaded = model.load_weights([(name, param)])
+                    del param
+                    loaded_weights.update(loaded)
+            logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
+            t.stop()
+            return
+        # normal update_weights without shared memory prefetching
+        if version is None:
+            raise ValueError(
+                "version must be provided if not using shared_memory_state_dict"
+            )
+        logger.info("[PolicyWorker] update weights from torchstore.")
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
         use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
         loaded_weights = set()
-        t = Tracer("worker_perf/update_weights", timer="gpu")
+        t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
         t.start()

         if use_dcp_for_weight_sync:
@@ -622,3 +723,27 @@ async def _test_validate_model_params(self, validate_fn):
         return validate_fn(
             self._test_prev_params, self.worker.model_runner.model, logger
         )
+
+
+class _WeightFetcher(ForgeActor):
+    """Fetches weights from torchstore and loads them into shared memory.
+    This has to be colocated with the GeneratorWorker."""
+
+    @endpoint
+    async def fetch(
+        self,
+        *,
+        version: int,
+        param_names: list[str],
+    ) -> dict[str, SharedTensorHandle]:
+        """Fetch weights from torchstore and load them into shared memory."""
+        sd = {}
+        for name in param_names:
+            param_key = get_param_key(version, name)
+            param = await ts.get(param_key)
+            # Use context manager to ensure cleanup after getting handle
+            with SharedTensor(tensor=param) as shared_tensor:
+                handle = shared_tensor.get_handle()
+                sd[name] = handle
+            del param  # Explicitly free the tensor after copying to shared memory
+        return sd
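The diff imports SharedTensor and SharedTensorHandle from forge.util._shared_tensor but does not show their implementation. Below is a minimal sketch of what such a pair could look like on top of the stdlib multiprocessing.shared_memory, purely as an assumption to make the fetch / load / drop lifecycle above concrete; the real classes may differ.

import math
from dataclasses import dataclass
from multiprocessing import shared_memory

import torch


@dataclass
class SharedTensorHandle:
    """Picklable description of a tensor living in a named shared-memory segment (assumed layout)."""
    shm_name: str
    shape: tuple
    dtype: torch.dtype

    def to_shared_tensor(self) -> "SharedTensor":
        return SharedTensor(handle=self)

    def drop(self) -> None:
        # Unlink the segment once every consumer has detached from it.
        shm = shared_memory.SharedMemory(name=self.shm_name)
        shm.close()
        shm.unlink()


class SharedTensor:
    def __init__(self, tensor=None, handle=None):
        if tensor is not None:
            # Producer side: copy the tensor into a freshly created shared-memory segment.
            cpu = tensor.detach().cpu().contiguous()
            self._shm = shared_memory.SharedMemory(
                create=True, size=cpu.numel() * cpu.element_size()
            )
            self.tensor = torch.frombuffer(
                self._shm.buf, dtype=cpu.dtype, count=cpu.numel()
            ).view(cpu.shape)
            self.tensor.copy_(cpu)
            self._handle = SharedTensorHandle(self._shm.name, tuple(cpu.shape), cpu.dtype)
        else:
            # Consumer side: attach to an existing segment by name, zero-copy.
            self._shm = shared_memory.SharedMemory(name=handle.shm_name)
            self.tensor = torch.frombuffer(
                self._shm.buf, dtype=handle.dtype, count=math.prod(handle.shape)
            ).view(handle.shape)
            self._handle = handle

    def get_handle(self) -> SharedTensorHandle:
        return self._handle

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # The mapping can only be closed once no tensor still exports its buffer,
        # which is why the worker code above does `del param` inside the `with` block.
        del self.tensor
        self._shm.close()
        return False

Under this sketch's semantics, exiting the `with` block only detaches the local process's mapping; the segment itself survives until `handle.drop()` unlinks it, which matches `_drop_shared_memory` being called by the Generator only after the workers have finished loading.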
