@@ -88,40 +88,39 @@ struct BeamSearchParams {
8888 // but we have to match trt-llm API.
8989};
9090
91- void launchMHA (
92- cudaDeviceProp const & prop, uint32_t const nbKHeads,
91+ void launchMHA (cudaDeviceProp const & prop, uint32_t const nbKHeads,
9392#if SLIDING_WINDOW
94- uint32_t slidingWinSize,
93+ uint32_t slidingWinSize,
9594#endif
96- float qScale, OutputHead* output,
95+ float qScale, OutputHead* output,
9796#if LOW_PREC_OUTPUT
98- float const * rcpOutScale,
97+ float const * rcpOutScale,
9998#endif
10099#if USE_INPUT_KV
101- InputHead const * qkv,
100+ InputHead const * qkv,
102101#if ROPE_STYLE != 0
103- Vec<float , validElemsPerHead> const * ropeCosSin,
102+ Vec<float , validElemsPerHead> const * ropeCosSin,
104103#endif
105104#else
106- InputHead const * q,
107- #endif
108- float const * attentionSinks, // [headGrpSize]
109- GMemCacheHead* kCacheVLLM , GMemCacheHead* vCacheVLLM,
110- KVCachePageIndex const *
111- kvCachePageList, // device pointer. shape:
112- // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
113- uint32_t maxSeqLen, uint32_t const * seqLen,
105+ InputHead const * q,
106+ #endif
107+ float const * attentionSinks, // [headGrpSize]
108+ GMemCacheHead* kCacheVLLM , GMemCacheHead* vCacheVLLM,
109+ KVCachePageIndex const *
110+ kvCachePageList, // device pointer. shape:
111+ // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
112+ uint32_t maxSeqLen, uint32_t const * seqLen,
114113#if BEAM_WIDTH > 1
115- BeamSearchParams const & beamSearchParams,
114+ BeamSearchParams const & beamSearchParams,
116115#endif
117- uint32_t batchSize,
118- float const * __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
119- // Used only for int8/fp8 KV cache.
116+ uint32_t batchSize,
117+ float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
118+ // Used only for int8/fp8 KV cache.
120119#if SPEC_DEC
121- SpecDecParams const & specDecParams,
120+ SpecDecParams const & specDecParams,
122121#endif
123- uint32_t * semaphores, void * scratch, bool enable_pdl, uint64_t kv_stride_page,
124- uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
122+ uint32_t * semaphores, void * scratch, bool enable_pdl, uint64_t kv_stride_page,
123+ uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
125124
126125void launchMHAFlashInfer (uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize,
127126 float qScale, OutputHead* output,
@@ -131,7 +130,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
131130 InputHead const * q, float const * attentionSinks, GMemCacheHead* kCacheVLLM ,
132131 GMemCacheHead* vCacheVLLM, KVCachePageIndex const * kvCachePageList,
133132 uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
134- float const * __restrict__ kvCacheScale,
133+ float kvCacheScale,
135134#if SPEC_DEC
136135 uint32_t qSeqLen, uint32_t const * qCuSeqLens, MaskType const * mask,
137136#endif
@@ -166,8 +165,8 @@ void launchHopperF8MHA(
166165 BeamSearchParams const & beamSearchParams,
167166#endif
168167 uint32_t batchSize,
169- float const * __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
170- // Used only for int8/fp8 KV cache.
168+ float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
169+ // Used only for int8/fp8 KV cache.
171170#if SPEC_DEC
172171 SpecDecParams const & specDecParams,
173172#endif
@@ -181,28 +180,26 @@ void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads
181180 InputHead const * q, float const * attentionSinks,
182181 GMemCacheHead* kCacheVLLM , GMemCacheHead* vCacheVLLM,
183182 KVCachePageIndex const * kvCachePageList, uint32_t maxSeqLen,
184- uint32_t const * seqLen, uint32_t batchSize,
185- float const * __restrict__ kvCacheScale,
183+ uint32_t const * seqLen, uint32_t batchSize, float kvCacheScale,
186184#if SPEC_DEC
187185 uint32_t qSeqLen, uint32_t const * qCuSeqLens, MaskType const * mask,
188186#endif
189187 uint32_t * semaphores, void * scratch, bool enable_pdl,
190188 uint64_t kv_stride_page, uint64_t kv_stride_token,
191189 uint64_t kv_stride_head, cudaStream_t stream);
192190
193- void launchMLA (
194- cudaDeviceProp const & prop,
195- uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
196- float qScale, OutputHead* output, InputHead const * q, GMemCacheHead* kCacheVLLM ,
197- GMemCacheHead* vCacheVLLM,
198- KVCachePageIndex const *
199- kvCachePageList, // device pointer. shape:
200- // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
201- // [batchSize][maxNbPagesPerSeq] (Layout 1)
202- uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
203- float const * __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
204- // Used only for int8/fp8 KV cache.
205- uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
191+ void launchMLA (cudaDeviceProp const & prop,
192+ uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
193+ float qScale, OutputHead* output, InputHead const * q, GMemCacheHead* kCacheVLLM ,
194+ GMemCacheHead* vCacheVLLM,
195+ KVCachePageIndex const *
196+ kvCachePageList, // device pointer. shape:
197+ // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
198+ // (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1)
199+ uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
200+ float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
201+ // Used only for int8/fp8 KV cache.
202+ uint32_t * semaphores, void * scratch, bool enable_pdl, cudaStream_t stream);
206203
207204void launchMLAFlashInfer (
208205 uint32_t multiProcessorCount,
@@ -214,8 +211,8 @@ void launchMLAFlashInfer(
214211 // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
215212 // [batchSize][maxNbPagesPerSeq] (Layout 1)
216213 uint32_t maxSeqLen, uint32_t const * seqLen, uint32_t batchSize,
217- float const * __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache.
218- // Used only for int8/fp8 KV cache.
214+ float kvCacheScale, // Device memory scalar. Same scale for K and V cache.
215+ // Used only for int8/fp8 KV cache.
219216 uint32_t * semaphores, void * scratch, bool enable_pdl, uint64_t kv_stride_page,
220217 uint64_t kv_stride_token, uint64_t kv_stride_head, cudaStream_t stream);
221218
0 commit comments