Commit 9081522

rebase
1 parent a07e86f commit 9081522

File tree

1 file changed: +3 -259 lines changed

python/flashinfer/decode.py

Lines changed: 3 additions & 259 deletions
@@ -824,267 +824,11 @@ def __init__(
             should be large enough to store the maximum batch size (``[max_batch_size]``)
             during the lifecycle of this wrapper.
         """
-        check_kv_layout(kv_layout)
-        self._kv_layout = kv_layout
-        self._workspace_buffer = workspace_buffer
-        max_batch_size = len(last_page_len_buffer)
-        self._wrapper = _kernels.CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(
-            TensorLayout[kv_layout].value,
-            max_batch_size,
-        )
-        self._paged_kv_indptr_buf = indptr_buffer
-        self._paged_kv_indices_buf = indices_buffer
-        self._paged_kv_last_page_len_buf = last_page_len_buffer
-
-    def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
-        r"""Reset the workspace buffer.
-
-        Parameters
-        ----------
-        new_workspace_buffer : torch.Tensor
-            The new workspace buffer, the device of the new workspace buffer should
-            be the same as the device of the input tensors.
-        """
-        self._workspace_buffer = new_workspace_buffer
-
-    def begin_forward(
-        self,
-        indptr: torch.Tensor,
-        indices: torch.Tensor,
-        last_page_len: torch.Tensor,
-        num_qo_heads: int,
-        num_kv_heads: int,
-        head_dim: int,
-        page_size: int,
-        pos_encoding_mode: str = "NONE",
-        data_type: Union[str, torch.dtype] = "float16",
-    ):
-        r"""Create auxiliary data structures for batch decode for multiple forward calls
-        within the same decode step.
-
-        Parameters
-        ----------
-        indptr : torch.Tensor
-            The indptr of the paged kv cache, shape: ``[batch_size + 1]``
-        indices : torch.Tensor
-            The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]``
-        last_page_len : torch.Tensor
-            The number of entries in the last page of each request in the paged kv
-            cache, shape: ``[batch_size]``
-        num_qo_heads : int
-            The number of query/output heads
-        num_kv_heads : int
-            The number of key/value heads
-        head_dim : int
-            The dimension of the heads
-        page_size : int
-            The page size of the paged kv cache
-        pos_encoding_mode : str
-            Whether to apply RoPE on-the-fly inside attention kernels, could be
-            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding)/``ALIBI``.
-        data_type : Union[str, torch.dtype]
-            The data type of the paged kv cache
-
-        Note
-        ----
-        The :meth:`begin_forward` method should be called before any :meth:`forward` or
-        :meth:`forward_return_lse` calls, auxiliary data structures will be created
-        during this call and cached for multiple forward calls.
-
-        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
-        is not equal to ``num_kv_heads``, the function will use
-        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
-        """
-
-        self._paged_kv_indptr_buf[: len(indptr)] = indptr
-        self._paged_kv_indices_buf[: len(indices)] = indices
-        self._paged_kv_last_page_len_buf[: len(last_page_len)] = last_page_len
-
-        batch_size = len(indptr) - 1
-        # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
-        empty_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self._wrapper.begin_forward(
-            self._workspace_buffer,
-            indptr,
-            last_page_len,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            page_size,
-            PosEncodingMode[pos_encoding_mode].value,
-            empty_data,
-        )
-
-    def end_forward(self):
-        r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
-        self._wrapper.end_forward()
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        paged_kv_data: torch.Tensor,
-        pos_encoding_mode: str = "NONE",
-        q_scale: Optional[float] = None,
-        k_scale: Optional[float] = None,
-        v_scale: Optional[float] = None,
-        sm_scale: Optional[float] = None,
-        rope_scale: Optional[float] = None,
-        rope_theta: Optional[float] = None,
-    ):
-        r"""Compute batch decode attention between query and paged kv cache.
-
-        Parameters
-        ----------
-        q : torch.Tensor
-            The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``
-        paged_kv_data : torch.Tensor
-            A 5-D tensor of the reserved paged kv-cache data, shape:
-            ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``, or
-            ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``HND``.
-        pos_encoding_mode : str
-            Whether to apply RoPE on-the-fly inside attention kernels, could be
-            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding)/``ALIBI``.
-        q_scale : Optional[float]
-            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
-        k_scale : Optional[float]
-            The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
-        v_scale : Optional[float]
-            The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
-        sm_scale : Optional[float]
-            The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
-        rope_scale : Optional[float]
-            The scale used in RoPE interpolation, if not provided, will be set to
-            ``1.0``.
-        rope_theta : Optional[float]
-            The theta used in RoPE, if not provided, will be set to ``1e4``.
-
-        Returns
-        -------
-        torch.Tensor
-            The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
-        """
-        check_pos_encoding_mode(pos_encoding_mode)
-        if sm_scale is None:
-            head_dim = q.shape[-1]
-            sm_scale = 1.0 / math.sqrt(head_dim)
-        if q_scale is not None:
-            sm_scale *= q_scale
-        if k_scale is not None:
-            sm_scale *= k_scale
-        if rope_scale is None:
-            rope_scale = 1.0
-        if rope_theta is None:
-            rope_theta = 1e4
-
-        paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
-        out = self._wrapper.forward(
-            q,
-            paged_kv_data,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len_buf,
-            PosEncodingMode[pos_encoding_mode].value,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            False,
-        )[0]
-        if v_scale is not None:
-            out *= v_scale
-        return out
-
-    def forward_return_lse(
-        self,
-        q: torch.Tensor,
-        paged_kv_data: torch.Tensor,
-        pos_encoding_mode: str = "NONE",
-        q_scale: Optional[float] = None,
-        k_scale: Optional[float] = None,
-        v_scale: Optional[float] = None,
-        sm_scale: Optional[float] = None,
-        rope_scale: Optional[float] = None,
-        rope_theta: Optional[float] = None,
-    ):
-        r"""Compute batch decode attention with paged kv cache, return attention output
-        and logsumexp of attention scores.
-
-        Parameters
-        ----------
-        q : torch.Tensor
-            The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``
-        paged_kv_data : torch.Tensor
-            A 5-D tensor of the reserved paged kv-cache data, shape:
-            ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``, or
-            ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``HND``.
-        pos_encoding_mode : str
-            Whether to apply RoPE on-the-fly inside attention kernels, could be
-            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding)/``ALIBI``.
-        q_scale : Optional[float]
-            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
-        k_scale : Optional[float]
-            The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
-        v_scale : Optional[float]
-            The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
-        sm_scale : Optional[float]
-            The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
-        rope_scale : Optional[float]
-            The scale used in RoPE interpolation, if not provided, will be set to
-            ``1.0``.
-        rope_theta : Optional[float]
-            The theta used in RoPE, if not provided, will be set to ``1e4``.
-
-        Returns
-        -------
-        V : torch.Tensor
-            The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
-        S : torch.Tensor
-            The logsumexp of attention scores, Shape: ``[batch_size, num_qo_heads]``.
-
-        Notes
-        -----
-        Please refer to the :ref:`tutorial <recursive-attention>` for a detailed
-        explanation of the log-sum-exp function and attention states.
-        """
-        check_pos_encoding_mode(pos_encoding_mode)
-        if sm_scale is None:
-            head_dim = q.shape[-1]
-            sm_scale = 1.0 / math.sqrt(head_dim)
-        if q_scale is not None:
-            sm_scale *= q_scale
-        if k_scale is not None:
-            sm_scale *= k_scale
-        if rope_scale is None:
-            rope_scale = 1.0
-        if rope_theta is None:
-            rope_theta = 1e4
-        paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
-        V, s = self._wrapper.forward(
-            q,
-            paged_kv_data,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len_buf,
-            self._batch_size,
-            self._nnz_pages,
-            PosEncodingMode[pos_encoding_mode].value,
-            sm_scale,
-            rope_scale,
-            rope_theta,
+        super().__init__(
+            workspace_buffer,
+            kv_layout,
             True,
             indptr_buffer,
             indices_buffer,
             last_page_len_buffer,
         )
-        if v_scale is not None:
-            V *= v_scale
-        return V, s
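
The net effect of the commit is that the CUDA-graph batch decode wrapper no longer carries its own copies of reset_workspace_buffer / begin_forward / end_forward / forward / forward_return_lse; its __init__ now delegates to the parent wrapper via super().__init__(workspace_buffer, kv_layout, True, indptr_buffer, indices_buffer, last_page_len_buffer), with True selecting the CUDA-graph path. The caller-facing API documented in the removed docstrings is unchanged. Below is a minimal usage sketch of that API; the Python class name CUDAGraphBatchDecodeWithPagedKVCacheWrapper, the constructor argument order, and all buffer sizes are assumptions for illustration only, while the begin_forward / forward / forward_return_lse / end_forward signatures are taken from the docstrings shown in this hunk.

import torch
import flashinfer

# Sizes chosen for illustration only (assumed, not from this diff).
max_batch_size = 16
max_num_pages = 1024
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

# Buffers must outlive the wrapper: a captured CUDA graph replays against
# these fixed device pointers.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
indices_buffer = torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
last_page_len_buffer = torch.empty(max_batch_size, dtype=torch.int32, device="cuda")

# Class name and argument order are assumed here.
wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, indptr_buffer, indices_buffer, last_page_len_buffer, kv_layout="NHD"
)

# One decode step for a batch of 4 requests, each owning 2 full pages.
batch_size = 4
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * 2
kv_indices = torch.arange(2 * batch_size, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

# begin_forward copies the page-table tensors into the persistent buffers and
# builds the auxiliary data structures reused by every forward in this step.
wrapper.begin_forward(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    pos_encoding_mode="NONE",
    data_type=torch.float16,
)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
# NHD layout: [max_num_pages, 2, page_size, num_kv_heads, head_dim]
paged_kv_data = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)

out = wrapper.forward(q, paged_kv_data)              # [batch_size, num_qo_heads, head_dim]
V, S = wrapper.forward_return_lse(q, paged_kv_data)  # attention output and logsumexp
wrapper.end_forward()

With the buffers held fixed, the forward call can be captured into a CUDA graph and replayed for later decode steps, as long as the batch size never exceeds the max_batch_size the buffers were allocated for (per the constructor docstring retained above).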
