
Commit 7c150c8

Policy cleaner launch / setup (#401)
1 parent a1167f0 commit 7c150c8

2 files changed: +83 -122 lines


src/forge/actors/generator.py

Lines changed: 59 additions & 90 deletions
@@ -22,7 +22,6 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
-from vllm.lora.request import LoRARequest
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -53,7 +52,6 @@
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
-from forge.interfaces import Policy as GeneratorInterface
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ -63,20 +61,19 @@
 
 
 @dataclass
-class Generator(GeneratorInterface):
-    """Instance of a vLLM-based Generator.
+class Generator(ForgeActor):
+    """Instance of a vLLM-based generator.
 
     This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The
     main difference is that all communications are controlled here via Monarch's proc meshes.
 
     Args:
         engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
         sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
-        available_devices (str): The available devices to use for the vLLM engine.
-        use_dcp (bool): Whether to use DCP for NFS-based weight sync.
+        use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on
+            whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled.
 
     Example:
-
         >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
         ...     engine_args=EngineArgs(...),
         ...     sampling_params=SamplingParams(...),
@@ -89,50 +86,50 @@ class Generator(GeneratorInterface):
 
     engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
     sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
-    available_devices: str | None = None
-    use_dcp: bool = (
-        TORCHSTORE_USE_RDMA.get_value() == 0
-    )  # torchstore currently only accepts 0 or 1
-    # Remaining variables are initialized in self.setup()
-    lora_request: LoRARequest | None = None
-    tokenization_kwargs: dict = field(default_factory=dict)
-    generator_worker: GeneratorWorker | None = None
+    use_dcp_for_weight_sync: bool | None = None
 
     def __post_init__(self):
         super().__init__()
         self._run_task: asyncio.Task | None = None
         self._generator_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
+        self.worker: GeneratorWorker | None = None
         self.running = False
         self.generator_version: int = 0
 
         if isinstance(self.engine_args, Mapping):
             self.engine_args = EngineArgs(**self.engine_args)
         self.engine_args._is_v1_supported_oracle = lambda *_: True
+        self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
 
         if isinstance(self.sampling_params, Mapping):
             self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
         self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
 
+        if self.use_dcp_for_weight_sync is None:
+            self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value()
+        logger.debug(f"{self.use_dcp_for_weight_sync=}")
+
+    @endpoint
+    async def get_vllm_config(self) -> VllmConfig:
+        return self.vllm_config
+
+    @endpoint
+    async def register_worker(self, worker: GeneratorWorker) -> None:
+        self.worker = worker
+        logger.debug("Registered GeneratorWorker on Generator.")
+
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Generator"],
-        *,
-        engine_args: EngineArgs | Mapping = EngineArgs(),
-        sampling_params: SamplingParams | Mapping = SamplingParams(),
-        available_devices: str | None = None,
-        use_dcp: bool = (
-            TORCHSTORE_USE_RDMA.get_value() == 0
-        ),  # torchstore currently only accepts 0 or 1
+        *args,
         **kwargs,
     ) -> "Generator":
-        """Launch the Generator with its workers.
+        """Custom launch for the Generator service with its workers.
 
         We overwrite the default Service launch method in order to setup Actors (GeneratorWorker) within this "coordinating" Actor.
         We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers
         and the generator in setup.
-
-        The args here generally should match those in the `__init__` method of the Generator class.
         """
         # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
         process_config: ProcessConfig = ProcessConfig(
@@ -141,60 +138,46 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
             with_gpus=cls.with_gpus,
             mesh_name=cls.mesh_name,
         )
-        worker_procs = await get_proc_mesh(process_config=process_config)
 
         # TODO - issues/144 we will want to ensure colocation with workers
         # We're currently locating the Generator on the local host proc mesh
         # vLLM initialization without setting env variables at proc_mesh creation
-        # level leads to issues.
-        # Once we can create multiple proc meshes on a host mesh, we can ensure
-        # host colocation
+        # level leads to issues. Once we can create multiple proc meshes on a host mesh,
+        # we can ensure host colocation
         generator_proc_config = copy(process_config)
         generator_proc_config.procs = 1
         generator_proc_config.hosts = None
         generator_proc_config.with_gpus = False
         generator_proc = await get_proc_mesh(process_config=generator_proc_config)
 
-        if isinstance(engine_args, Mapping):
-            engine_args = EngineArgs(**engine_args)
-        engine_args._is_v1_supported_oracle = lambda *_: True  # Always default on
-        logger.debug(f"Resolved engine args: {engine_args}")
-
-        vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
-        workers = worker_procs.spawn(
-            "vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp
-        )
-
-        if isinstance(sampling_params, Mapping):
-            sampling_params = SamplingParams.from_optional(**sampling_params)
-        sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
-        logger.debug(f"Resolved sampling params: {sampling_params}")
-
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
         generator = generator_proc.spawn(
             actor_name,
             cls,
-            engine_args=engine_args,
-            sampling_params=sampling_params,
-            available_devices=available_devices,
-            generator_worker=workers,
+            *args,
             **kwargs,
         )
+
+        worker_procs = await get_proc_mesh(process_config=process_config)
+        vllm_config = (
+            await generator.get_vllm_config.call_one()
+        )  # Config should be the same across all actors
+        worker = worker_procs.spawn(
+            "vllm_worker", GeneratorWorker, vllm_config=vllm_config
+        )
+        await worker.setup.call()
+        await generator.register_worker.call(worker)
+
         generator._generator_proc = generator_proc
         generator._worker_procs = worker_procs
         await generator.setup.call()
+
         return generator
 
     @endpoint
     async def setup(self):
         """Mirrors the __init__ of vLLM's LLMEngine."""
-        if self.generator_worker is None:
-            raise RuntimeError(
-                "Geneator worker should not be None. Usually it would be attached to Generator in the ``launch`` method."
-            )
-        await self.generator_worker.setup.call()
-
         self.request_id = 0
         self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
 
@@ -204,35 +187,30 @@ async def setup(self):
         self.request_lock = asyncio.Condition()  # Guard for accepting_requests
         self.update_lock = asyncio.Condition()  # Guard for updating requests
 
-        vllm_config: VllmConfig = self.engine_args.create_engine_config(
-            UsageContext.LLM_CLASS
-        )
-        self.max_model_len = vllm_config.model_config.max_model_len
-
         # Setup processors
         # TODO: move all processing to the Environment
         # TODO: add support for `log_stats` and `mm_registry`
         tokenizer = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config,
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            lora_config=self.vllm_config.lora_config,
         )
         self.processor = Processor(
-            vllm_config=vllm_config, tokenizer=tokenizer, mm_registry=None
+            vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
         )
         self.output_processor = OutputProcessor(tokenizer, log_stats=None)
 
         # Configure KV caches
-        kv_cache_configs = await self.generator_worker.setup_kv_cache.call()
+        kv_cache_configs = await self.worker.setup_kv_cache.call()
         _, kv_cache_config = next(kv_cache_configs.items())
-        vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
-        vllm_config.cache_config.num_cpu_blocks = 0
+        self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        self.vllm_config.cache_config.num_cpu_blocks = 0
 
         # Setup scheduler
         # TODO: Add support for `log_stats`
-        structured_output_manager = StructuredOutputManager(vllm_config)
+        structured_output_manager = StructuredOutputManager(self.vllm_config)
         self.scheduler = Scheduler(
-            vllm_config=vllm_config,
+            vllm_config=self.vllm_config,
             kv_cache_config=kv_cache_config,
             structured_output_manager=structured_output_manager,
             include_finished_set=False,
@@ -262,11 +240,11 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)
 
-        tokenization_kwargs = self.tokenization_kwargs or {}
+        tokenization_kwargs = {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
         truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
         _validate_truncation_size(
-            self.max_model_len,
+            self.vllm_config.model_config.max_model_len,
             truncate_prompt_tokens,
             tokenization_kwargs,
         )
@@ -275,7 +253,6 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
             prompt={"prompt": prompt},
             params=self.sampling_params,
             arrival_time=None,
-            lora_request=self.lora_request,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
             priority=priority,
@@ -360,9 +337,7 @@ async def run(self) -> None:
         self.running = True
         while self.running:
             scheduler_output = self.scheduler.schedule()
-            worker_outputs = await self.generator_worker.execute_model.call(
-                scheduler_output
-            )
+            worker_outputs = await self.worker.execute_model.call(scheduler_output)
 
             # The results of `execute_model` are gathered on the driver rank (rank 0)
             _, worker_output = next(worker_outputs.items())
@@ -431,8 +406,8 @@ async def update_weights(self, version: int) -> None:
         )
 
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        # Call update_weights on every generator_worker
-        await self.generator_worker.update_weights.call(version=version)
+        # Call update_weights on every generator worker
+        await self.worker.update_weights.call(version=version)
         self.generator_version = version
 
         # After updating the weights, we need to reset the KV cache
@@ -511,13 +486,13 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
     async def _test_save_model_params(self):
         """Save model parameters before weight update, used for tesing purposes only."""
         logger.info("[Generator] save model parameters for testing.")
-        await self.generator_worker._test_save_model_params.call()
+        await self.worker._test_save_model_params.call()
 
     @endpoint
     async def _test_validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
         logger.info("[Generator] start validating model parameters.")
-        return await self.generator_worker._test_validate_model_params.call(validate_fn)
+        return await self.worker._test_validate_model_params.call(validate_fn)
 
 
 @dataclass
@@ -530,17 +505,9 @@ class GeneratorWorker(ForgeActor):
     """
 
     vllm_config: VllmConfig
-    state_dict_key: str = "model_state_dict"
-    # TODO: remove this later since no plumbing exists to change this value.
-    # Also, whether to use dcp or not can be inferred from torchstore get() call.
-    use_dcp: bool = True
-
-    # used for tesing purposes only
+    # TODO: Remove below param
    _test_prev_params = {}
 
-    def __post_init__(self):
-        super().__init__()
-
     @endpoint
     async def setup(self):
         self.rank = current_rank().rank
@@ -602,19 +569,20 @@ async def update_weights(self, version: int) -> None:
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
+        use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
         loaded_weights = set()
         t = Tracer("worker_perf/update_weights", timer="gpu")
         t.start()
-        # Entire state dict is stored in a single DCP handle
-        if dcp_whole_state_dict_key in matching_keys:
+
+        if use_dcp_for_weight_sync:
             dcp_handle = await ts.get(dcp_whole_state_dict_key)
             hf_param_names = dcp_handle.param_names
             for name in hf_param_names:
                 param = load_tensor_from_dcp(dcp_handle, name)
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
-        else:  # Load each parameter from torchstore directly without DCP
+        else:
             hf_param_names = [extract_param_name(key) for key in matching_keys]
             # We can't pass a generator since vllm load_weights is not async.
             # Instead, we just call load_weights with one parameter at a time.
@@ -624,6 +592,7 @@ async def update_weights(self, version: int) -> None:
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
+
         t.stop()
 
     @endpoint

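For readers following the new control flow, the docstring's Example expands to roughly the sketch below. This is a hedged illustration, not code from this commit: the model name and sampling values are placeholders, and the Generator import path is assumed from this file's location (src/forge/actors/generator.py).

import asyncio

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from forge.actors.generator import Generator  # import path assumed from the diff


async def main() -> None:
    # as_service() routes through the overridden launch() above: spawn the
    # Generator on a CPU-only proc mesh, fetch the resolved vllm_config back
    # through the new get_vllm_config endpoint, spawn GeneratorWorker on the
    # GPU mesh with that same config, run worker.setup, then register the
    # worker on the Generator before generator.setup runs.
    generator = await Generator.options(
        procs=1, num_replicas=1, with_gpus=True
    ).as_service(
        engine_args=EngineArgs(model="facebook/opt-125m"),  # placeholder model
        sampling_params=SamplingParams(max_tokens=64),  # placeholder params
    )


if __name__ == "__main__":
    asyncio.run(main())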
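One design note on the weight-sync cleanup: GeneratorWorker no longer carries a use_dcp constructor flag. Each update_weights call instead infers the transport from which keys the publisher actually wrote to torchstore. Below is a minimal sketch of that check, using the helper names visible in this diff (get_param_prefix, get_dcp_whole_state_dict_key, and the ts torchstore handle); the standalone function is hypothetical, since in the real code this logic runs inline inside GeneratorWorker.update_weights.

async def uses_dcp_for_version(version: int) -> bool:
    # Hypothetical helper mirroring the in-diff logic.
    prefix = get_param_prefix(version)
    matching_keys = await ts.keys(prefix)
    dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
    # The publisher either stored the entire state dict as a single DCP handle
    # under this key (the NFS/DCP path) or wrote each parameter individually
    # (the path taken when torchstore RDMA is enabled).
    return dcp_whole_state_dict_key in matching_keys

This pairs with Generator.use_dcp_for_weight_sync defaulting to not TORCHSTORE_USE_RDMA.get_value(): when RDMA is on, parameters are pushed individually, so the key check above resolves to False on the worker side.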