 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
-from vllm.lora.request import LoRARequest
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ ... @@
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
-from forge.interfaces import Policy as GeneratorInterface
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ ... @@
 
 
 @dataclass
-class Generator(GeneratorInterface):
-    """Instance of a vLLM-based Generator.
+class Generator(ForgeActor):
+    """Instance of a vLLM-based generator.
 
     This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The
     main difference is that all communications are controlled here via Monarch's proc meshes.
 
     Args:
         engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
         sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
-        available_devices (str): The available devices to use for the vLLM engine.
-        use_dcp (bool): Whether to use DCP for NFS-based weight sync.
+        use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. The default depends
+            on whether RDMA is enabled in torchstore: if RDMA is enabled, DCP is disabled; otherwise it is enabled.
 
     Example:
-
         >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
         ...     engine_args=EngineArgs(...),
         ...     sampling_params=SamplingParams(...),
@@ -89,50 +86,50 @@ class Generator(GeneratorInterface):
8986
9087 engine_args : EngineArgs | Mapping = field (default_factory = EngineArgs )
9188 sampling_params : SamplingParams | Mapping = field (default_factory = SamplingParams )
92- available_devices : str | None = None
93- use_dcp : bool = (
94- TORCHSTORE_USE_RDMA .get_value () == 0
95- ) # torchstore currently only accepts 0 or 1
96- # Remaining variables are initialized in self.setup()
97- lora_request : LoRARequest | None = None
98- tokenization_kwargs : dict = field (default_factory = dict )
99- generator_worker : GeneratorWorker | None = None
89+ use_dcp_for_weight_sync : bool | None = None
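+    # None is a sentinel meaning "decide in __post_init__ from torchstore's RDMA setting".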
 
     def __post_init__(self):
         super().__init__()
         self._run_task: asyncio.Task | None = None
         self._generator_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
+        self.worker: GeneratorWorker | None = None
         self.running = False
         self.generator_version: int = 0
 
         if isinstance(self.engine_args, Mapping):
             self.engine_args = EngineArgs(**self.engine_args)
         self.engine_args._is_v1_supported_oracle = lambda *_: True
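+        # Build the vLLM config once up front; setup() and the spawned workers both consume it.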
+        self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
 
         if isinstance(self.sampling_params, Mapping):
             self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
         self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
 
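+        # Resolve the DCP default lazily so an explicit user setting always wins.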
+        if self.use_dcp_for_weight_sync is None:
+            self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value()
+        logger.debug(f"{self.use_dcp_for_weight_sync=}")
+
+    @endpoint
+    async def get_vllm_config(self) -> VllmConfig:
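+        """Return the resolved vLLM config so that launch() can hand it to the workers."""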
+        return self.vllm_config
+
+    @endpoint
+    async def register_worker(self, worker: GeneratorWorker) -> None:
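+        """Attach the spawned GeneratorWorker mesh to this coordinating actor."""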
+        self.worker = worker
+        logger.debug("Registered GeneratorWorker on Generator.")
+
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Generator"],
-        *,
-        engine_args: EngineArgs | Mapping = EngineArgs(),
-        sampling_params: SamplingParams | Mapping = SamplingParams(),
-        available_devices: str | None = None,
-        use_dcp: bool = (
-            TORCHSTORE_USE_RDMA.get_value() == 0
-        ),  # torchstore currently only accepts 0 or 1
+        *args,
         **kwargs,
     ) -> "Generator":
-        """Launch the Generator with its workers.
+        """Custom launch for the Generator service with its workers.
 
         We overwrite the default Service launch method in order to set up Actors (GeneratorWorker) within this "coordinating" Actor.
         We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers
         and the generator in setup.
-
-        The args here generally should match those in the `__init__` method of the Generator class.
         """
         # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
         process_config: ProcessConfig = ProcessConfig(
@@ -141,60 +138,46 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             with_gpus=cls.with_gpus,
             mesh_name=cls.mesh_name,
         )
-        worker_procs = await get_proc_mesh(process_config=process_config)
 
         # TODO - issues/144 we will want to ensure colocation with workers
         # We're currently locating the Generator on the local host proc mesh
         # vLLM initialization without setting env variables at proc_mesh creation
-        # level leads to issues.
-        # Once we can create multiple proc meshes on a host mesh, we can ensure
-        # host colocation
+        # level leads to issues. Once we can create multiple proc meshes on a host mesh,
+        # we can ensure host colocation.
         generator_proc_config = copy(process_config)
         generator_proc_config.procs = 1
         generator_proc_config.hosts = None
         generator_proc_config.with_gpus = False
         generator_proc = await get_proc_mesh(process_config=generator_proc_config)
 
-        if isinstance(engine_args, Mapping):
-            engine_args = EngineArgs(**engine_args)
-        engine_args._is_v1_supported_oracle = lambda *_: True  # Always default on
-        logger.debug(f"Resolved engine args: {engine_args}")
-
-        vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
-        workers = worker_procs.spawn(
-            "vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp
-        )
-
-        if isinstance(sampling_params, Mapping):
-            sampling_params = SamplingParams.from_optional(**sampling_params)
-        sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
-        logger.debug(f"Resolved sampling params: {sampling_params}")
-
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
        generator = generator_proc.spawn(
             actor_name,
             cls,
-            engine_args=engine_args,
-            sampling_params=sampling_params,
-            available_devices=available_devices,
-            generator_worker=workers,
+            *args,
             **kwargs,
         )
+
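+        # Spawn the workers only after the Generator exists, so they can be handed the vLLM config it resolved.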
+        worker_procs = await get_proc_mesh(process_config=process_config)
+        vllm_config = (
+            await generator.get_vllm_config.call_one()
+        )  # Config should be the same across all actors
+        worker = worker_procs.spawn(
+            "vllm_worker", GeneratorWorker, vllm_config=vllm_config
+        )
+        await worker.setup.call()
+        await generator.register_worker.call(worker)
+
         generator._generator_proc = generator_proc
         generator._worker_procs = worker_procs
         await generator.setup.call()
+
         return generator
 
     @endpoint
     async def setup(self):
         """Mirrors the __init__ of vLLM's LLMEngine."""
-        if self.generator_worker is None:
-            raise RuntimeError(
-                "Geneator worker should not be None. Usually it would be attached to Generator in the ``launch`` method."
-            )
-        await self.generator_worker.setup.call()
-
         self.request_id = 0
         self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
 
@@ -204,35 +187,30 @@ async def setup(self):
         self.request_lock = asyncio.Condition()  # Guard for accepting_requests
         self.update_lock = asyncio.Condition()  # Guard for updating requests
 
-        vllm_config: VllmConfig = self.engine_args.create_engine_config(
-            UsageContext.LLM_CLASS
-        )
-        self.max_model_len = vllm_config.model_config.max_model_len
-
         # Setup processors
         # TODO: move all processing to the Environment
         # TODO: add support for `log_stats` and `mm_registry`
         tokenizer = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config,
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            lora_config=self.vllm_config.lora_config,
         )
         self.processor = Processor(
-            vllm_config=vllm_config, tokenizer=tokenizer, mm_registry=None
+            vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
         )
         self.output_processor = OutputProcessor(tokenizer, log_stats=None)
 
         # Configure KV caches
-        kv_cache_configs = await self.generator_worker.setup_kv_cache.call()
+        kv_cache_configs = await self.worker.setup_kv_cache.call()
         _, kv_cache_config = next(kv_cache_configs.items())
-        vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
-        vllm_config.cache_config.num_cpu_blocks = 0
+        self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        self.vllm_config.cache_config.num_cpu_blocks = 0
 
         # Setup scheduler
         # TODO: Add support for `log_stats`
-        structured_output_manager = StructuredOutputManager(vllm_config)
+        structured_output_manager = StructuredOutputManager(self.vllm_config)
         self.scheduler = Scheduler(
-            vllm_config=vllm_config,
+            vllm_config=self.vllm_config,
             kv_cache_config=kv_cache_config,
             structured_output_manager=structured_output_manager,
             include_finished_set=False,
@@ -262,11 +240,11 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)
 
-        tokenization_kwargs = self.tokenization_kwargs or {}
+        tokenization_kwargs = {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
         truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
         _validate_truncation_size(
-            self.max_model_len,
+            self.vllm_config.model_config.max_model_len,
             truncate_prompt_tokens,
             tokenization_kwargs,
         )
@@ -275,7 +253,6 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
             prompt={"prompt": prompt},
             params=self.sampling_params,
             arrival_time=None,
-            lora_request=self.lora_request,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
             priority=priority,
@@ -360,9 +337,7 @@ async def run(self) -> None:
         self.running = True
         while self.running:
             scheduler_output = self.scheduler.schedule()
-            worker_outputs = await self.generator_worker.execute_model.call(
-                scheduler_output
-            )
+            worker_outputs = await self.worker.execute_model.call(scheduler_output)
 
             # The results of `execute_model` are gathered on the driver rank (rank 0)
             _, worker_output = next(worker_outputs.items())
@@ -431,8 +406,8 @@ async def update_weights(self, version: int) -> None:
         )
 
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        # Call update_weights on every generator_worker
-        await self.generator_worker.update_weights.call(version=version)
+        # Call update_weights on every generator worker
+        await self.worker.update_weights.call(version=version)
         self.generator_version = version
 
         # After updating the weights, we need to reset the KV cache
@@ -511,13 +486,13 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
     async def _test_save_model_params(self):
         """Save model parameters before weight update, used for testing purposes only."""
         logger.info("[Generator] save model parameters for testing.")
-        await self.generator_worker._test_save_model_params.call()
+        await self.worker._test_save_model_params.call()
 
     @endpoint
     async def _test_validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
         logger.info("[Generator] start validating model parameters.")
-        return await self.generator_worker._test_validate_model_params.call(validate_fn)
+        return await self.worker._test_validate_model_params.call(validate_fn)
 
 
 @dataclass
@@ -530,17 +505,9 @@ class GeneratorWorker(ForgeActor):
530505 """
531506
532507 vllm_config : VllmConfig
533- state_dict_key : str = "model_state_dict"
534- # TODO: remove this later since no plumbing exists to change this value.
535- # Also, whether to use dcp or not can be inferred from torchstore get() call.
536- use_dcp : bool = True
537-
538- # used for tesing purposes only
508+ # TODO: Remove below param
539509 _test_prev_params = {}
540510
541- def __post_init__ (self ):
542- super ().__init__ ()
543-
544511 @endpoint
545512 async def setup (self ):
546513 self .rank = current_rank ().rank
@@ -602,19 +569,20 @@ async def update_weights(self, version: int) -> None:
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
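+        # A single DCP handle holding the entire state dict signals that DCP was used for weight sync.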
+        use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
         loaded_weights = set()
         t = Tracer("worker_perf/update_weights", timer="gpu")
         t.start()
-        # Entire state dict is stored in a single DCP handle
-        if dcp_whole_state_dict_key in matching_keys:
+
+        if use_dcp_for_weight_sync:
             dcp_handle = await ts.get(dcp_whole_state_dict_key)
             hf_param_names = dcp_handle.param_names
             for name in hf_param_names:
                 param = load_tensor_from_dcp(dcp_handle, name)
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
-        else:  # Load each parameter from torchstore directly without DCP
+        else:
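+            # Load each parameter from torchstore directly, without DCP.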
             hf_param_names = [extract_param_name(key) for key in matching_keys]
             # We can't pass a generator since vllm load_weights is not async.
             # Instead, we just call load_weights with one parameter at a time.
@@ -624,6 +592,7 @@ async def update_weights(self, version: int) -> None:
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
+
         t.stop()
 
     @endpoint