Skip to content

Commit baf37d6

Browse files
committed
Patch vllm vit to use xformers when head dim not divisible by 32
Related: vllm-project/vllm#27562
Signed-off-by: Yi-Fu Wu <[email protected]>
1 parent fa2ccf4 commit baf37d6

File tree

1 file changed

+116
-44
lines changed

1 file changed

+116
-44
lines changed

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 116 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import gc
1717
import os
1818
import sys
19+
from importlib.util import find_spec
1920
from typing import Any, Optional, cast
2021

2122
import ray
@@ -157,63 +158,134 @@ def __init__(
157158
self.rank = 0
158159
self.world_size = 1
159160

160-
# Monkey patch for vLLM to ensure RAY_ADDRESS is set in Ray actors.
161-
try:
162-
from vllm.logger import init_logger
161+
# Monkey patches for vLLM behavior. We avoid importing vllm modules
162+
# here to prevent side effects during initialization and instead
163+
# locate the files via importlib metadata.
163164

164-
logger = init_logger("vllm_patch")
165+
from vllm.logger import init_logger
165166

166-
def _patch_vllm_init_workers_ray():
167-
"""Patch the vLLM ray_distributed_executor.py file.
167+
logger = init_logger("vllm_patch")
168168

169-
1. Pass custom runtime_env in _init_workers_ray call.
170-
- This allows passing custom py_executable to worker initialization.
171-
2. Add NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE to vLLM ADDITIONAL_ENV_VARS.
172-
- This is a workaround to fix async vllm in some scenarios.
173-
- See https://github.com/NVIDIA-NeMo/RL/pull/898 for more details.
174-
"""
175-
try:
176-
import vllm.executor.ray_distributed_executor as ray_executor_module
169+
def _get_vllm_file(relative_path: str) -> str:
170+
"""Return absolute path to a vLLM file or raise if it cannot be found.
171+
172+
The relative_path should be a POSIX-style path under the vllm
173+
package root, e.g. "v1/executor/ray_executor.py" or
174+
"attention/layer.py".
175+
"""
176+
spec = find_spec("vllm")
177+
if spec is None or not spec.submodule_search_locations:
178+
raise RuntimeError(
179+
"vLLM package not found while attempting to patch "
180+
f"'{relative_path}'. Ensure vLLM is installed and "
181+
"available in this environment."
182+
)
177183

178-
file_to_patch = ray_executor_module.__file__
184+
base_dir = next(iter(spec.submodule_search_locations))
185+
file_path = os.path.join(base_dir, *relative_path.split("/"))
179186

180-
with open(file_to_patch, "r") as f:
181-
content = f.read()
187+
if not os.path.exists(file_path):
188+
raise RuntimeError(
189+
"Failed to locate expected vLLM file to patch. "
190+
f"Looked for '{relative_path}' at '{file_path}'. "
191+
"This likely indicates an unexpected vLLM installation "
192+
"layout or version mismatch."
193+
)
194+
195+
return file_path
196+
197+
def _patch_vllm_init_workers_ray():
198+
"""Patch the vLLM ray_distributed_executor.py file.
199+
200+
1. Pass custom runtime_env in _init_workers_ray call.
201+
- This allows passing custom py_executable to worker initialization.
202+
2. Add NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE to vLLM ADDITIONAL_ENV_VARS.
203+
- This is a workaround to fix async vllm in some scenarios.
204+
- See https://github.com/NVIDIA-NeMo/RL/pull/898 for more details.
205+
"""
206+
file_to_patch = _get_vllm_file("v1/executor/ray_executor.py")
182207

183-
old_lines = [
184-
"self._init_workers_ray(placement_group)",
185-
'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}',
186-
]
208+
with open(file_to_patch, "r") as f:
209+
content = f.read()
187210

188-
new_lines = [
189-
f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{self.py_executable}"}})',
190-
'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "NCCL_CUMEM_ENABLE", "NCCL_NVLS_ENABLE", "RAY_ENABLE_UV_RUN_RUNTIME_ENV"}',
191-
]
211+
old_lines = [
212+
"self._init_workers_ray(placement_group)",
213+
'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}',
214+
]
192215

193-
need_replace = False
194-
for old_line, new_line in zip(old_lines, new_lines):
195-
if new_line in content or old_line not in content:
196-
continue
197-
content = content.replace(old_line, new_line)
198-
need_replace = True
216+
new_lines = [
217+
f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{self.py_executable}"}})',
218+
'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "NCCL_CUMEM_ENABLE", "NCCL_NVLS_ENABLE", "RAY_ENABLE_UV_RUN_RUNTIME_ENV"}',
219+
]
220+
221+
need_replace = False
222+
for old_line, new_line in zip(old_lines, new_lines):
223+
if new_line in content or old_line not in content:
224+
continue
225+
content = content.replace(old_line, new_line)
226+
need_replace = True
227+
228+
if not need_replace:
229+
return
230+
231+
# Write back the patched content
232+
with open(file_to_patch, "w") as f:
233+
f.write(content)
234+
235+
def _patch_vllm_vit_flash_attn_backend():
236+
"""Patch vLLM vision attention backend selection logic.
237+
238+
Modify the CUDA branch of maybe_get_vit_flash_attn_backend in
239+
vllm.attention.layer to avoid overriding the backend when it
240+
is already set to XFORMERS. This avoids flash attention related
241+
errors when the ViT head dimension is not a multiple of 32.
242+
243+
Related issues:
244+
- https://github.com/vllm-project/vllm/issues/27562
245+
- https://github.com/vllm-project/vllm/issues/26989
246+
247+
This is properly fixed in https://github.com/vllm-project/vllm/pull/28763. We can remove this patch once we upgrade to a version of vllm that contains this fix.
248+
"""
249+
file_to_patch = _get_vllm_file("attention/layer.py")
250+
with open(file_to_patch, "r") as f:
251+
content = f.read()
252+
253+
old_snippet = (
254+
" elif current_platform.is_cuda():\n"
255+
" if (\n"
256+
" attn_backend != AttentionBackendEnum.FLASH_ATTN\n"
257+
" and check_upstream_fa_availability(torch.get_default_dtype())\n"
258+
" ):\n"
259+
" attn_backend = AttentionBackendEnum.FLASH_ATTN\n"
260+
" use_upstream_fa = True\n"
261+
)
262+
263+
new_snippet = (
264+
" elif current_platform.is_cuda():\n"
265+
" if (\n"
266+
" attn_backend != AttentionBackendEnum.FLASH_ATTN\n"
267+
" and attn_backend != AttentionBackendEnum.XFORMERS\n"
268+
" and check_upstream_fa_availability(torch.get_default_dtype())\n"
269+
" ):\n"
270+
" attn_backend = AttentionBackendEnum.FLASH_ATTN\n"
271+
" use_upstream_fa = True\n"
272+
)
199273

200-
if not need_replace:
201-
return
274+
# Only patch if the file still has the old snippet and
275+
# hasn't been patched already.
276+
if new_snippet in content or old_snippet not in content:
277+
return
202278

203-
# Write back the patched content
204-
with open(file_to_patch, "w") as f:
205-
f.write(content)
279+
content = content.replace(old_snippet, new_snippet)
206280

207-
except (ImportError, FileNotFoundError, PermissionError):
208-
# Allow failures gracefully
209-
pass
281+
with open(file_to_patch, "w") as f:
282+
f.write(content)
210283

211-
_patch_vllm_init_workers_ray()
212-
logger.info("Successfully patched vllm _init_workers_ray.")
284+
_patch_vllm_init_workers_ray()
285+
logger.info("Successfully patched vllm _init_workers_ray.")
213286

214-
except (ImportError, AttributeError):
215-
# vllm not installed or has a different structure, skipping patch.
216-
pass
287+
_patch_vllm_vit_flash_attn_backend()
288+
logger.info("Successfully patched vllm vit flash attention backend.")
217289

218290
try:
219291
import vllm

0 commit comments

Comments (0)