@@ -13,10 +13,12 @@
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import dataclass, field
+from typing import Optional

 import torch
 import torchstore as ts
-from monarch.actor import current_rank, endpoint, ProcMesh
+from monarch.actor import current_rank, endpoint, ProcMesh, this_host
+
 from vllm.config import VllmConfig

 from vllm.engine.arg_utils import EngineArgs
@@ -60,6 +62,7 @@
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
+from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -92,6 +95,8 @@ class Generator(ForgeActor):
     engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
     sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
     use_dcp_for_weight_sync: bool | None = None
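+    # When prefetch_weights_to_shm is enabled, n_fetcher_procs colocated fetcher
+    # processes copy weights from torchstore into shared memory before the workers
+    # load them (see _spawn_fetchers and _fetch_weights below).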
+    prefetch_weights_to_shm: bool = True
+    n_fetcher_procs: int = 8

     def __post_init__(self):
         super().__init__()
@@ -226,11 +231,61 @@ async def setup(self):
             log_stats=None,
         )
         self._start_processing()
+        if self.prefetch_weights_to_shm:
+            self._spawn_fetchers()
+
+    def _spawn_fetchers(self):
+        """Spawn weight fetchers that prefetch weights from torchstore to shared memory."""
+        # TODO: this assumes the generator is on the same host as the worker
+        # and only works for single-host generators. Figure out how to support
+        # generators with workers spanning multiple hosts.
+        fetcher_procs = this_host().spawn_procs(
+            per_host={"procs": self.n_fetcher_procs}
+        )
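+        # Keep a handle to the fetcher procs so shutdown() can stop them.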
+        self._fetcher_procs = fetcher_procs
+        self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)

     def _start_processing(self):
         if self._run_task is None or self._run_task.done():
             self._run_task = asyncio.create_task(self.run())

+    async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
+        for handle in state_dict.values():
+            handle.drop()
+
+    async def _fetch_weights(
+        self,
+        version: int,
+    ) -> dict[str, SharedTensorHandle]:
+        """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
+        t = Tracer("generator_perf/_fetch_weights")
+        t.start()
+        prefix = get_param_prefix(version)
+        matching_keys = await ts.keys(prefix)
+        hf_param_names = [extract_param_name(key) for key in matching_keys]
+
+        n_fetchers = self.weight_fetchers.size()
+
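+        # Shard parameter names round-robin across the fetcher processes.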
+        def split_keys(keys):
+            return [keys[i::n_fetchers] for i in range(n_fetchers)]
+
+        futures = []
+        for i, names in enumerate(split_keys(hf_param_names)):
+            fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
+                version=version, param_names=names
+            )
+            futures.append(fut)
+
+        sub_state_dicts = [await fut for fut in futures]
+
+        state_dict = {}
+        for sd in sub_state_dicts:
+            state_dict.update(sd)
+
+        t.stop()
+
+        return state_dict
+
     @endpoint
     async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         """Generate a response for the given prompt
@@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None:
             >>> await trainer.push_weights()
             >>> generator.update_weights(version)
         """
+        # TODO: enable shared memory prefetch for DCP-based weight sync
+        if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
+            logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
+            fetch_fut = asyncio.create_task(self._fetch_weights(version))
+        else:
+            fetch_fut = None
         # Serialize updates (only one update at a time)
         async with self.update_lock:
             # Grab the lock to stop accepting requests and wait on pending requests
@@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None:
             )

             logger.debug(f"Starting weight update on {self.__class__.__name__}")
-            # Call update_weights on every generator worker
-            await self.worker.update_weights.call(version=version)
+
+            if fetch_fut is not None:
+                t = Tracer("generator_perf/waiting_for_fetch_weights")
+                t.start()
+                fetched_weights = await fetch_fut
+                t.stop()
+                # Call update_weights on every generator worker
+                await self.worker.update_weights.call(
+                    shared_memory_state_dict=fetched_weights
+                )
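+                # Release the shared-memory segments once every worker has loaded the weights.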
+                await self._drop_shared_memory(fetched_weights)
+            else:
+                await self.worker.update_weights.call(version=version)
             self.generator_version = version

             # After updating the weights, we need to reset the KV cache
@@ -490,6 +562,7 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
         await actor.stop.call()
         await stop_proc_mesh(actor._worker_procs)
         await stop_proc_mesh(actor._generator_proc)
+        await stop_proc_mesh(actor._fetcher_procs)

     @endpoint
     async def _test_save_model_params(self):
@@ -573,14 +646,42 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
         return self.worker.execute_model(schedule)

     @endpoint
-    async def update_weights(self, version: int) -> None:
+    async def update_weights(
+        self,
+        version: Optional[int] = None,
+        *,
+        shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
+    ) -> None:
         model = self.worker.model_runner.model
+        if shared_memory_state_dict is not None:
+            logger.info("[PolicyWorker] update weights from shared memory.")
+            t = Tracer(
+                "generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
+            )
+            t.start()
+            loaded_weights = set()
+            for name, param_handle in shared_memory_state_dict.items():
+                # Use context manager for automatic cleanup
+                with param_handle.to_shared_tensor() as shared_tensor:
+                    param = shared_tensor.tensor
+                    loaded = model.load_weights([(name, param)])
+                    del param
+                    loaded_weights.update(loaded)
+            logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
+            t.stop()
+            return
+        # normal update_weights without shared memory prefetching
+        if version is None:
+            raise ValueError(
+                "version must be provided if not using shared_memory_state_dict"
+            )
+        logger.info("[PolicyWorker] update weights from torchstore.")
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
         use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
         loaded_weights = set()
-        t = Tracer("worker_perf/update_weights", timer="gpu")
+        t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
         t.start()

         if use_dcp_for_weight_sync:
@@ -622,3 +723,27 @@ async def _test_validate_model_params(self, validate_fn):
         return validate_fn(
             self._test_prev_params, self.worker.model_runner.model, logger
         )
+
+
+class _WeightFetcher(ForgeActor):
+    """Fetches weights from torchstore and loads them into shared memory.
+    This has to be colocated with the GeneratorWorker."""
+
+    @endpoint
+    async def fetch(
+        self,
+        *,
+        version: int,
+        param_names: list[str],
+    ) -> dict[str, SharedTensorHandle]:
+        """Fetch weights from torchstore and load them into shared memory."""
+        sd = {}
+        for name in param_names:
+            param_key = get_param_key(version, name)
+            param = await ts.get(param_key)
+            # Use context manager to ensure cleanup after getting handle
+            with SharedTensor(tensor=param) as shared_tensor:
+                handle = shared_tensor.get_handle()
+                sd[name] = handle
+            del param  # Explicitly free the tensor after copying to shared memory
+        return sd