Skip to content

Commit 00f526f

Browse files
seperate gpu wait
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent a8439e2 commit 00f526f

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

vllm/v1/worker/ubatching.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def __init__(self,
3737
def __enter__(self):
3838
global _CURRENT_CONTEXT
3939
_CURRENT_CONTEXT[threading.get_ident()] = self
40-
self._wait()
40+
self._cpu_wait()
41+
self.gpu_stream_wait()
4142
return self
4243

4344
def __exit__(self, exc_type, exc_val, exc_tb):
@@ -53,20 +54,39 @@ def _restore_context(self):
5354
torch.cuda.set_stream(self.stream)
5455
forward_context._forward_context = self.forward_context
5556

56-
def yield_(self):
57+
# Separate GPU wait so we can do
58+
# ubatch0
59+
# 1) work
60+
# 2) dispatch
61+
# 3) yield
62+
# ubatch1
63+
# 1) work
64+
# 2) gpu wait
65+
# 3) dispatch
66+
# 4) yield
67+
#
68+
# This way the CPU can schedule ubatch1's dispatch while ubatch0 is
69+
# still running, before yielding back to ubatch1, but we ensure we won't
70+
# start that dispatch until ubatch0's dispatch is done, avoiding
71+
# overlapping dispatches that might share underlying buffers
72+
def gpu_stream_wait(self):
73+
self.stream.wait_event(self.gpu_wait_event)
74+
75+
def yield_(self, gpu_wait: bool = True):
5776
self._signal()
58-
self._wait()
77+
self._cpu_wait()
78+
if gpu_wait:
79+
self.gpu_stream_wait()
5980

6081
def _signal(self):
6182
# Wait for the next batch to signal back
6283
self.gpu_signal_event.record(self.stream)
6384
# Signal that this batch reached the barrier
6485
self.cpu_signal_event.set()
6586

66-
def _wait(self):
87+
def _cpu_wait(self):
6788
self.cpu_wait_event.wait()
6889
self.cpu_wait_event.clear()
69-
self.stream.wait_event(self.gpu_wait_event)
7090
self._restore_context()
7191

7292
_CURRENT_CONTEXT: dict = {}
@@ -121,7 +141,7 @@ def make_ubatch_context_chain(
121141
ctxs.append(ctx)
122142

123143
def start_hook(from_stream: torch.cuda.Stream):
124-
ctxs[0].cpu_wait_event.set()
125-
ctxs[0].gpu_wait_event.record(from_stream)
144+
ctxs[0].gpu_wait_event.record(from_stream)
145+
ctxs[0].cpu_wait_event.set()
126146

127147
return ctxs, start_hook

0 commit comments

Comments
 (0)