|
16 | 16 | import gc |
17 | 17 | import os |
18 | 18 | import sys |
| 19 | +from importlib.util import find_spec |
19 | 20 | from typing import Any, Optional, cast |
20 | 21 |
|
21 | 22 | import ray |
@@ -157,63 +158,134 @@ def __init__( |
157 | 158 | self.rank = 0 |
158 | 159 | self.world_size = 1 |
159 | 160 |
|
160 | | - # Monkey patch for vLLM to ensure RAY_ADDRESS is set in Ray actors. |
161 | | - try: |
162 | | - from vllm.logger import init_logger |
|  161 | +        # Monkey patches for vLLM behavior. We avoid importing the vllm |
|  162 | +        # executor/attention modules here to prevent side effects during |
|  163 | +        # initialization; their files are located via importlib's find_spec. |
163 | 164 |
|
164 | | - logger = init_logger("vllm_patch") |
| 165 | + from vllm.logger import init_logger |
165 | 166 |
|
166 | | - def _patch_vllm_init_workers_ray(): |
167 | | - """Patch the vLLM ray_distributed_executor.py file. |
| 167 | + logger = init_logger("vllm_patch") |
168 | 168 |
|
169 | | - 1. Pass custom runtime_env in _init_workers_ray call. |
170 | | - - This allows passing custom py_executable to worker initialization. |
171 | | - 2. Add NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE to vLLM ADDITIONAL_ENV_VARS. |
172 | | - - This is a workaround to fix async vllm in some scenarios. |
173 | | - - See https://github.com/NVIDIA-NeMo/RL/pull/898 for more details. |
174 | | - """ |
175 | | - try: |
176 | | - import vllm.executor.ray_distributed_executor as ray_executor_module |
| 169 | + def _get_vllm_file(relative_path: str) -> str: |
| 170 | + """Return absolute path to a vLLM file or raise if it cannot be found. |
| 171 | +
|
| 172 | + The relative_path should be a POSIX-style path under the vllm |
| 173 | + package root, e.g. "v1/executor/ray_executor.py" or |
| 174 | + "attention/layer.py". |
| 175 | + """ |
| 176 | + spec = find_spec("vllm") |
| 177 | + if spec is None or not spec.submodule_search_locations: |
| 178 | + raise RuntimeError( |
| 179 | + "vLLM package not found while attempting to patch " |
| 180 | + f"'{relative_path}'. Ensure vLLM is installed and " |
| 181 | + "available in this environment." |
| 182 | + ) |
177 | 183 |
|
178 | | - file_to_patch = ray_executor_module.__file__ |
| 184 | + base_dir = next(iter(spec.submodule_search_locations)) |
| 185 | + file_path = os.path.join(base_dir, *relative_path.split("/")) |
179 | 186 |
|
180 | | - with open(file_to_patch, "r") as f: |
181 | | - content = f.read() |
| 187 | + if not os.path.exists(file_path): |
| 188 | + raise RuntimeError( |
| 189 | + "Failed to locate expected vLLM file to patch. " |
| 190 | + f"Looked for '{relative_path}' at '{file_path}'. " |
| 191 | + "This likely indicates an unexpected vLLM installation " |
| 192 | + "layout or version mismatch." |
| 193 | + ) |
| 194 | + |
| 195 | + return file_path |
| 196 | + |
| 197 | + def _patch_vllm_init_workers_ray(): |
|  198 | +            """Patch the vLLM v1/executor/ray_executor.py file. |
| 199 | +
|
| 200 | + 1. Pass custom runtime_env in _init_workers_ray call. |
| 201 | + - This allows passing custom py_executable to worker initialization. |
| 202 | + 2. Add NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE to vLLM ADDITIONAL_ENV_VARS. |
| 203 | + - This is a workaround to fix async vllm in some scenarios. |
| 204 | + - See https://github.com/NVIDIA-NeMo/RL/pull/898 for more details. |
| 205 | + """ |
| 206 | + file_to_patch = _get_vllm_file("v1/executor/ray_executor.py") |
182 | 207 |
|
183 | | - old_lines = [ |
184 | | - "self._init_workers_ray(placement_group)", |
185 | | - 'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}', |
186 | | - ] |
| 208 | + with open(file_to_patch, "r") as f: |
| 209 | + content = f.read() |
187 | 210 |
|
188 | | - new_lines = [ |
189 | | - f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{self.py_executable}"}})', |
190 | | - 'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "NCCL_CUMEM_ENABLE", "NCCL_NVLS_ENABLE", "RAY_ENABLE_UV_RUN_RUNTIME_ENV"}', |
191 | | - ] |
| 211 | + old_lines = [ |
| 212 | + "self._init_workers_ray(placement_group)", |
| 213 | + 'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}', |
| 214 | + ] |
192 | 215 |
|
193 | | - need_replace = False |
194 | | - for old_line, new_line in zip(old_lines, new_lines): |
195 | | - if new_line in content or old_line not in content: |
196 | | - continue |
197 | | - content = content.replace(old_line, new_line) |
198 | | - need_replace = True |
| 216 | + new_lines = [ |
| 217 | + f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{self.py_executable}"}})', |
| 218 | + 'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "NCCL_CUMEM_ENABLE", "NCCL_NVLS_ENABLE", "RAY_ENABLE_UV_RUN_RUNTIME_ENV"}', |
| 219 | + ] |
| 220 | + |
| 221 | + need_replace = False |
| 222 | + for old_line, new_line in zip(old_lines, new_lines): |
| 223 | + if new_line in content or old_line not in content: |
| 224 | + continue |
| 225 | + content = content.replace(old_line, new_line) |
| 226 | + need_replace = True |
| 227 | + |
| 228 | + if not need_replace: |
| 229 | + return |
| 230 | + |
| 231 | + # Write back the patched content |
| 232 | + with open(file_to_patch, "w") as f: |
| 233 | + f.write(content) |
| 234 | + |
| 235 | + def _patch_vllm_vit_flash_attn_backend(): |
| 236 | + """Patch vLLM vision attention backend selection logic. |
| 237 | +
|
| 238 | + Modify the CUDA branch of maybe_get_vit_flash_attn_backend in |
| 239 | + vllm.attention.layer to avoid overriding the backend when it |
| 240 | + is already set to XFORMERS. This avoids flash attention related |
| 241 | + errors when the ViT head dimension is not a multiple of 32. |
| 242 | +
|
| 243 | + Related issues: |
| 244 | + - https://github.com/vllm-project/vllm/issues/27562 |
| 245 | + - https://github.com/vllm-project/vllm/issues/26989 |
| 246 | +
|
| 247 | + This is properly fixed in https://github.com/vllm-project/vllm/pull/28763. We can remove this patch once we upgrade to a version of vllm that contains this fix. |
| 248 | + """ |
| 249 | + file_to_patch = _get_vllm_file("attention/layer.py") |
| 250 | + with open(file_to_patch, "r") as f: |
| 251 | + content = f.read() |
| 252 | + |
| 253 | + old_snippet = ( |
| 254 | + " elif current_platform.is_cuda():\n" |
| 255 | + " if (\n" |
| 256 | + " attn_backend != AttentionBackendEnum.FLASH_ATTN\n" |
| 257 | + " and check_upstream_fa_availability(torch.get_default_dtype())\n" |
| 258 | + " ):\n" |
| 259 | + " attn_backend = AttentionBackendEnum.FLASH_ATTN\n" |
| 260 | + " use_upstream_fa = True\n" |
| 261 | + ) |
| 262 | + |
| 263 | + new_snippet = ( |
| 264 | + " elif current_platform.is_cuda():\n" |
| 265 | + " if (\n" |
| 266 | + " attn_backend != AttentionBackendEnum.FLASH_ATTN\n" |
| 267 | + " and attn_backend != AttentionBackendEnum.XFORMERS\n" |
| 268 | + " and check_upstream_fa_availability(torch.get_default_dtype())\n" |
| 269 | + " ):\n" |
| 270 | + " attn_backend = AttentionBackendEnum.FLASH_ATTN\n" |
| 271 | + " use_upstream_fa = True\n" |
| 272 | + ) |
199 | 273 |
|
200 | | - if not need_replace: |
201 | | - return |
| 274 | + # Only patch if the file still has the old snippet and |
| 275 | + # hasn't been patched already. |
| 276 | + if new_snippet in content or old_snippet not in content: |
| 277 | + return |
202 | 278 |
|
203 | | - # Write back the patched content |
204 | | - with open(file_to_patch, "w") as f: |
205 | | - f.write(content) |
| 279 | + content = content.replace(old_snippet, new_snippet) |
206 | 280 |
|
207 | | - except (ImportError, FileNotFoundError, PermissionError): |
208 | | - # Allow failures gracefully |
209 | | - pass |
| 281 | + with open(file_to_patch, "w") as f: |
| 282 | + f.write(content) |
210 | 283 |
|
211 | | - _patch_vllm_init_workers_ray() |
212 | | - logger.info("Successfully patched vllm _init_workers_ray.") |
| 284 | + _patch_vllm_init_workers_ray() |
| 285 | + logger.info("Successfully patched vllm _init_workers_ray.") |
213 | 286 |
|
214 | | - except (ImportError, AttributeError): |
215 | | - # vllm not installed or has a different structure, skipping patch. |
216 | | - pass |
| 287 | + _patch_vllm_vit_flash_attn_backend() |
| 288 | + logger.info("Successfully patched vllm vit flash attention backend.") |
217 | 289 |
|
218 | 290 | try: |
219 | 291 | import vllm |
|
0 commit comments