@@ -824,267 +824,11 @@ def __init__(
          should be large enough to store the maximum batch size (``[max_batch_size]``)
          during the lifecycle of this wrapper.
          """
-        check_kv_layout(kv_layout)
-        self._kv_layout = kv_layout
-        self._workspace_buffer = workspace_buffer
-        max_batch_size = len(last_page_len_buffer)
-        self._wrapper = _kernels.CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(
-            TensorLayout[kv_layout].value,
-            max_batch_size,
-        )
-        self._paged_kv_indptr_buf = indptr_buffer
-        self._paged_kv_indices_buf = indices_buffer
-        self._paged_kv_last_page_len_buf = last_page_len_buffer
-
-    def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
-        r"""Reset the workspace buffer.
-
-        Parameters
-        ----------
-        new_workspace_buffer : torch.Tensor
-            The new workspace buffer; its device should be the same as the
-            device of the input tensors.
-        """
-        self._workspace_buffer = new_workspace_buffer
-
-    def begin_forward(
-        self,
-        indptr: torch.Tensor,
-        indices: torch.Tensor,
-        last_page_len: torch.Tensor,
-        num_qo_heads: int,
-        num_kv_heads: int,
-        head_dim: int,
-        page_size: int,
-        pos_encoding_mode: str = "NONE",
-        data_type: Union[str, torch.dtype] = "float16",
-    ):
862- r"""Create auxiliary data structures for batch decode for multiple forward calls
863- within the same decode step.
864-
865- Parameters
866- ----------
867- indptr : torch.Tensor
868- The indptr of the paged kv cache, shape: ``[batch_size + 1]``
-        indices : torch.Tensor
-            The page indices of the paged kv cache, shape: ``[indptr[-1]]``
-        last_page_len : torch.Tensor
-            The number of entries in the last page of each request in the paged kv
-            cache, shape: ``[batch_size]``
-        num_qo_heads : int
-            The number of query/output heads
-        num_kv_heads : int
-            The number of key/value heads
-        head_dim : int
-            The dimension of the heads
-        page_size : int
-            The page size of the paged kv cache
-        pos_encoding_mode : str
-            Whether to apply RoPE on-the-fly inside attention kernels, could be
-            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding)/``ALIBI``.
-        data_type : Union[str, torch.dtype]
-            The data type of the paged kv cache
-
-        Note
-        ----
-        The :meth:`begin_forward` method should be called before any :meth:`forward` or
-        :meth:`forward_return_lse` calls; auxiliary data structures will be created
-        during this call and cached for multiple forward calls.
-
-        ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
-        is not equal to ``num_kv_heads``, the function will use
-        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
-        """
-
-        self._paged_kv_indptr_buf[: len(indptr)] = indptr
-        self._paged_kv_indices_buf[: len(indices)] = indices
-        self._paged_kv_last_page_len_buf[: len(last_page_len)] = last_page_len
-
-        batch_size = len(indptr) - 1
-        # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
-        empty_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self._wrapper.begin_forward(
-            self._workspace_buffer,
-            indptr,
-            last_page_len,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            page_size,
-            PosEncodingMode[pos_encoding_mode].value,
-            empty_data,
-        )
-
-    def end_forward(self):
-        r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
-        self._wrapper.end_forward()
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        paged_kv_data: torch.Tensor,
-        pos_encoding_mode: str = "NONE",
-        q_scale: Optional[float] = None,
-        k_scale: Optional[float] = None,
-        v_scale: Optional[float] = None,
-        sm_scale: Optional[float] = None,
-        rope_scale: Optional[float] = None,
-        rope_theta: Optional[float] = None,
-    ):
940- r"""Compute batch decode attention between query and paged kv cache.
941-
942- Parameters
943- ----------
944- q : torch.Tensor
945- The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``
946- paged_kv_data : torch.Tensor
947- A 5-D tensor of the reserved paged kv-cache data, shape:
948- ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
949- :attr:`kv_layout` is ``NHD``, or
950- ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
951- :attr:`kv_layout` is ``HND``.
952- pos_encoding_mode : str
953- Whether to apply RoPE on-the-fly inside attention kernels, could be
954- ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
955- q_scale : Optional[float]
956- The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
957- k_scale : Optional[float]
958- The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
959- v_scale : Optional[float]
960- The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
961- sm_scale : Optional[float]
962- The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
963- rope_scale : Optional[float]
964- The scale used in RoPE interpolation, if not provided, will be set to
965- ``1.0``.
966- rope_theta : Optional[float]
967- The theta used in RoPE, if not provided, will be set to ``1e4``.
968-
969- Returns
970- -------
971- torch.Tensor
972- The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
973- """
-        check_pos_encoding_mode(pos_encoding_mode)
-        if sm_scale is None:
-            head_dim = q.shape[-1]
-            sm_scale = 1.0 / math.sqrt(head_dim)
-        if q_scale is not None:
-            sm_scale *= q_scale
-        if k_scale is not None:
-            sm_scale *= k_scale
-        if rope_scale is None:
-            rope_scale = 1.0
-        if rope_theta is None:
-            rope_theta = 1e4
-
-        paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
-        out = self._wrapper.forward(
-            q,
-            paged_kv_data,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len_buf,
-            PosEncodingMode[pos_encoding_mode].value,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            False,
-        )[0]
-        if v_scale is not None:
-            out *= v_scale
-        return out
-
-    def forward_return_lse(
-        self,
-        q: torch.Tensor,
-        paged_kv_data: torch.Tensor,
-        pos_encoding_mode: str = "NONE",
-        q_scale: Optional[float] = None,
-        k_scale: Optional[float] = None,
-        v_scale: Optional[float] = None,
-        sm_scale: Optional[float] = None,
-        rope_scale: Optional[float] = None,
-        rope_theta: Optional[float] = None,
-    ):
1016- r"""Compute batch decode attention with paged kv cache, return attention output
1017- and logsumexp of attention scores.
1018-
1019- Parameters
1020- ----------
1021- q : torch.Tensor
1022- The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``
1023- paged_kv_data : torch.Tensor
1024- A 5-D tensor of the reserved paged kv-cache data, shape:
1025- ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
1026- :attr:`kv_layout` is ``NHD``, or
1027- ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
1028- :attr:`kv_layout` is ``HND``.
1029- pos_encoding_mode : str
1030- Whether to apply RoPE on-the-fly inside attention kernels, could be
1031- ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
1032- q_scale : Optional[float]
1033- The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
1034- k_scale : Optional[float]
1035- The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
1036- v_scale : Optional[float]
1037- The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
1038- sm_scale : Optional[float]
1039- The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
1040- rope_scale : Optional[float]
1041- The scale used in RoPE interpolation, if not provided, will be set to
1042- ``1.0``.
1043- rope_theta : Optional[float]
1044- The theta used in RoPE, if not provided, will be set to ``1e4``.
1045-
1046- Returns
1047- -------
1048- V : torch.Tensor
1049- The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``.
1050- S : torch.Tensor
1051- The logsumexp of attention scores, Shape: ``[batch_size, num_qo_heads]``.
1052-
1053- Notes
1054- -----
1055- Please refer to the :ref:`tutorial <recursive-attention>` for a detailed
1056- explanation of the log-sum-exp function and attention states.
1057- """
-        check_pos_encoding_mode(pos_encoding_mode)
-        if sm_scale is None:
-            head_dim = q.shape[-1]
-            sm_scale = 1.0 / math.sqrt(head_dim)
-        if q_scale is not None:
-            sm_scale *= q_scale
-        if k_scale is not None:
-            sm_scale *= k_scale
-        if rope_scale is None:
-            rope_scale = 1.0
-        if rope_theta is None:
-            rope_theta = 1e4
-        paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
-        V, s = self._wrapper.forward(
-            q,
-            paged_kv_data,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len_buf,
-            self._batch_size,
-            self._nnz_pages,
-            PosEncodingMode[pos_encoding_mode].value,
-            sm_scale,
-            rope_scale,
-            rope_theta,
+        super().__init__(
+            workspace_buffer,
+            kv_layout,
             True,
             indptr_buffer,
             indices_buffer,
             last_page_len_buffer,
         )
-        if v_scale is not None:
-            V *= v_scale
-        return V, s
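
The change above collapses the CUDA-graph wrapper's hand-rolled ``__init__``, ``begin_forward``/``end_forward``, ``forward``, and ``forward_return_lse`` into a single ``super().__init__(...)`` call: the base paged-KV decode wrapper is handed the positional ``True`` (presumably a use-CUDA-graph flag) together with the pre-allocated ``indptr_buffer``/``indices_buffer``/``last_page_len_buffer``, so the CUDA-graph path now shares one implementation with the regular batch-decode path instead of duplicating it. The fixed-size buffers matter because a captured CUDA graph replays with the same device pointers; the wrapper copies each step's metadata into them rather than allocating fresh tensors.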
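For context, here is a minimal usage sketch of the wrapper whose internals are deleted above. This is an illustration, not the library's documented example: the module path ``flashinfer``, the Python-level class name ``CUDAGraphBatchDecodeWithPagedKVCacheWrapper``, and the constructor argument order are assumptions inferred from the parameter names visible in the diff, and the buffer sizes are arbitrary.

```python
# Hypothetical usage sketch (assumptions noted above): drive the CUDA-graph
# decode wrapper through one begin_forward/forward/end_forward cycle.
import torch
import flashinfer  # assumed module path

max_batch_size = 16
max_num_pages = 1024
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

# Buffers are allocated once at maximum size and reused every decode step,
# so a captured CUDA graph keeps pointing at the same device memory.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
indices_buffer = torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
last_page_len_buffer = torch.empty(max_batch_size, dtype=torch.int32, device="cuda")

wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    indptr_buffer,
    indices_buffer,
    last_page_len_buffer,
    kv_layout="NHD",
)

# A toy batch: 4 requests, each owning 2 full pages of paged KV cache.
batch_size, pages_per_req = 4, 2
indptr = torch.arange(
    0, (batch_size + 1) * pages_per_req, pages_per_req,
    dtype=torch.int32, device="cuda",
)
indices = torch.arange(batch_size * pages_per_req, dtype=torch.int32, device="cuda")
last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    indptr, indices, last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    data_type=torch.float16,
)
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_data = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)
out = wrapper.forward(q, kv_data)  # shape: [batch_size, num_qo_heads, head_dim]
wrapper.end_forward()
```

Note that ``begin_forward`` copies the per-step metadata into the pre-allocated buffers (``self._paged_kv_indptr_buf[: len(indptr)] = indptr`` and friends in the deleted code), which is why ``indptr``, ``indices``, and ``last_page_len`` may be smaller than the buffers but never larger.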