@@ -37,7 +37,8 @@ def __init__(self,
     def __enter__(self):
         global _CURRENT_CONTEXT
         _CURRENT_CONTEXT[threading.get_ident()] = self
-        self._wait()
+        self._cpu_wait()
+        self.gpu_stream_wait()
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
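The `__enter__` change splits the old `_wait()` into a CPU half (`_cpu_wait()`) and a GPU half (`gpu_stream_wait()`). The split works because `Stream.wait_event()` is asynchronous from the CPU's point of view: it only enqueues a wait on the stream and returns immediately. A minimal sketch of that property, assuming a CUDA device is available; the names here are illustrative, not from the patch:

```python
import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()
event = torch.cuda.Event()

with torch.cuda.stream(producer):
    a = torch.randn(1 << 20, device="cuda")
    total = a.sum()              # pending work on `producer`
event.record(producer)

consumer.wait_event(event)       # returns immediately on the CPU
with torch.cuda.stream(consumer):
    doubled = total * 2          # runs only after `event` on the GPU
torch.cuda.synchronize()
```

Because the CPU is never blocked by `wait_event()`, the caller can choose when each half happens, which is what the new `yield_(gpu_wait=...)` flag below exploits.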
@@ -53,20 +54,39 @@ def _restore_context(self):
         torch.cuda.set_stream(self.stream)
         forward_context._forward_context = self.forward_context

-    def yield_(self):
+    # Separate GPU wait so we can do
+    # ubatch0
+    # 1) work
+    # 2) dispatch
+    # 3) yield
+    # ubatch1
+    # 1) work
+    # 2) gpu wait
+    # 3) dispatch
+    # 4) yield
+    #
+    # This way the CPU can schedule ubatch1's dispatch while ubatch0's
+    # dispatch is still running before yielding back to ubatch1, but
+    # ensure we won't start the dispatch until ubatch0's dispatch is
+    # done, avoiding overlapping dispatches that might share
+    # underlying buffers.
+    def gpu_stream_wait(self):
+        self.stream.wait_event(self.gpu_wait_event)
+
+    def yield_(self, gpu_wait: bool = True):
         self._signal()
-        self._wait()
+        self._cpu_wait()
+        if gpu_wait:
+            self.gpu_stream_wait()

     def _signal(self):
         # Wait for the next batch to signal back
         self.gpu_signal_event.record(self.stream)
         # Signal that this batch reached the barrier
         self.cpu_signal_event.set()

-    def _wait(self):
+    def _cpu_wait(self):
         self.cpu_wait_event.wait()
         self.cpu_wait_event.clear()
-        self.stream.wait_event(self.gpu_wait_event)
         self._restore_context()

 _CURRENT_CONTEXT: dict = {}
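The `_signal()`/`_cpu_wait()` pair is a CPU-side ping-pong: each ubatch's `cpu_signal_event` is wired up as the other ubatch's `cpu_wait_event`, so setting it hands the CPU over to the peer thread. A CPU-only sketch of that handshake, with the CUDA events elided so it runs anywhere; the two-thread setup and names are illustrative:

```python
import threading

events = [threading.Event(), threading.Event()]
log = []

def ubatch(idx: int, steps: int):
    other = 1 - idx
    for step in range(steps):
        events[idx].wait()        # _cpu_wait(): block until our turn
        events[idx].clear()       # re-arm for the next round
        log.append((idx, step))   # "work" / "dispatch" would go here
        events[other].set()       # _signal(): hand the CPU to the peer

threads = [threading.Thread(target=ubatch, args=(i, 3)) for i in range(2)]
for t in threads:
    t.start()
events[0].set()                   # like start_hook: kick off ubatch0
for t in threads:
    t.join()
print(log)  # strictly alternates: (0, 0), (1, 0), (0, 1), (1, 1), ...
```

In the patch the GPU ordering rides alongside this handshake: `_signal()` records `gpu_signal_event` on the yielding stream, and the peer later consumes it via `gpu_stream_wait()` (the events are presumably cross-wired in `make_ubatch_context_chain`).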
@@ -121,7 +141,7 @@ def make_ubatch_context_chain(
         ctxs.append(ctx)

     def start_hook(from_stream: torch.cuda.Stream):
-        ctxs[0].cpu_wait_event.set()
-        ctxs[0].gpu_wait_event.record(from_stream)
+        ctxs[0].gpu_wait_event.record(from_stream)
+        ctxs[0].cpu_wait_event.set()

     return ctxs, start_hook
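The reorder in `start_hook` closes a startup race: a stream wait on a `torch.cuda.Event` that has never been recorded is treated as already complete, so with the old order the woken ubatch0 thread could enqueue its `wait_event()` before `record(from_stream)` ran, and that wait would order against nothing. Recording first guarantees the event is live by the time the CPU event wakes the waiter. A sketch of the fixed ordering, assuming a CUDA device; the helper names are hypothetical:

```python
import threading
import torch

cpu_event = threading.Event()
gpu_event = torch.cuda.Event()
ubatch_stream = torch.cuda.Stream()

def waiter():
    cpu_event.wait()                     # woken only after record()
    ubatch_stream.wait_event(gpu_event)  # orders against from_stream

def start_hook(from_stream: torch.cuda.Stream):
    gpu_event.record(from_stream)  # record first, so the wait is real
    cpu_event.set()                # then wake the waiting ubatch thread

t = threading.Thread(target=waiter)
t.start()
start_hook(torch.cuda.current_stream())
t.join()
```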