@@ -141,17 +141,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
141141 return DISPATCH_kv_layout (kv_layout_, KV_LAYOUT, [&] {
142142 return DISPATCH_pos_encoding_mode (
143143 PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
144- cudaError_t status =
145- handler_->BeginForwardDispatched <GROUP_SIZE, HEAD_DIM, PageStorage::kIndices ,
146- KV_LAYOUT, POS_ENCODING_MODE, c_type,
147- nv_half, int32_t >(
148- static_cast <void *>(workspace_buffer.data_ptr ()), workspace_size_in_bytes,
149- static_cast <int32_t *>(indptr.data_ptr ()),
150- static_cast <int32_t *>(last_page_len.data_ptr ()), batch_size, num_qo_heads,
151- page_size);
152- TORCH_CHECK (status == cudaSuccess,
153- " BatchDecodeWithPagedKVCache failed with error " ,
154- cudaGetErrorString (status));
144+ if (handler_->IsCUDAGraphMode ()) {
145+ // NOTE(Zihao): use runtime dispatch because template function is not virtual
146+ auto cuda_graph_handler_ =
147+ dynamic_cast <CUDAGraphBatchDecodeHandler*>(handler_.get ());
148+ cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched <
149+ GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
150+ c_type, nv_half, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
151+ workspace_size_in_bytes,
152+ static_cast <int32_t *>(indptr.data_ptr ()),
153+ static_cast <int32_t *>(last_page_len.data_ptr ()),
154+ batch_size, num_qo_heads, page_size);
155+ TORCH_CHECK (status == cudaSuccess,
156+ " BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error " ,
157+ cudaGetErrorString (status));
158+ } else {
159+ cudaError_t status = handler_->BeginForwardDispatched <
160+ GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
161+ c_type, nv_half, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
162+ workspace_size_in_bytes,
163+ static_cast <int32_t *>(indptr.data_ptr ()),
164+ static_cast <int32_t *>(last_page_len.data_ptr ()),
165+ batch_size, num_qo_heads, page_size);
166+ TORCH_CHECK (status == cudaSuccess,
167+ " BatchDecodeWithPagedKVCache failed with error " ,
168+ cudaGetErrorString (status));
169+ }
155170 return true ;
156171 });
157172 });
@@ -165,17 +180,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
165180 return DISPATCH_kv_layout (kv_layout_, KV_LAYOUT, [&] {
166181 return DISPATCH_pos_encoding_mode (
167182 PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
168- cudaError_t status =
169- handler_->BeginForwardDispatched <GROUP_SIZE, HEAD_DIM, PageStorage::kIndices ,
170- KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
171- int32_t >(
172- static_cast <void *>(workspace_buffer.data_ptr ()), workspace_size_in_bytes,
173- static_cast <int32_t *>(indptr.data_ptr ()),
174- static_cast <int32_t *>(last_page_len.data_ptr ()), batch_size, num_qo_heads,
175- page_size);
176- TORCH_CHECK (status == cudaSuccess,
177- " BatchDecodeWithPagedKVCache failed with error " ,
178- cudaGetErrorString (status));
183+ if (handler_->IsCUDAGraphMode ()) {
184+ // NOTE(Zihao): use runtime dispatch because template function is not virtual
185+ auto cuda_graph_handler_ =
186+ dynamic_cast <CUDAGraphBatchDecodeHandler*>(handler_.get ());
187+ auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched <
188+ GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
189+ c_type, c_type, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
190+ workspace_size_in_bytes,
191+ static_cast <int32_t *>(indptr.data_ptr ()),
192+ static_cast <int32_t *>(last_page_len.data_ptr ()),
193+ batch_size, num_qo_heads, page_size);
194+ TORCH_CHECK (status == cudaSuccess,
195+ " BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error " ,
196+ cudaGetErrorString (status));
197+ } else {
198+ cudaError_t status = handler_->BeginForwardDispatched <
199+ GROUP_SIZE, HEAD_DIM, PageStorage::kIndices , KV_LAYOUT, POS_ENCODING_MODE,
200+ c_type, c_type, int32_t >(static_cast <void *>(workspace_buffer.data_ptr ()),
201+ workspace_size_in_bytes,
202+ static_cast <int32_t *>(indptr.data_ptr ()),
203+ static_cast <int32_t *>(last_page_len.data_ptr ()),
204+ batch_size, num_qo_heads, page_size);
205+ TORCH_CHECK (status == cudaSuccess,
206+ " BatchDecodeWithPagedKVCache failed with error " ,
207+ cudaGetErrorString (status));
208+ }
179209 return true ;
180210 });
181211 });
0 commit comments