@@ -128,18 +128,36 @@ struct Identity<T, ComplexSum> {
128128 static constexpr T value = {0 , 0 };
129129};
130130
// Primary template for the block-prefix callback functor. It is only declared
// here; the partial specializations that follow select an implementation from
// the Op type and the UseKahan flag.
131+ template <typename T, typename Op, bool UseKahan>
132+ struct BlockPrefixCallbackOp ;
133+
// Specialization selected when UseKahan is false. Each call combines the
// incoming aggregate into a running total with op_ and returns the total as it
// stood before that call.
131134template <typename T, typename Op>
132- struct BlockPrefixCallbackOp {
135+ struct BlockPrefixCallbackOp <T, Op, false > {
133136 // Running prefix: op_-combination of every aggregate passed in so far,
// starting from the identity value given to the constructor.
134137 T running_total_;
135- T compensation_;
136138 Op op_;
137139
// Seeds the running total with `identity` and stores the combining functor.
138140 __device__ BlockPrefixCallbackOp (T identity, Op op)
139- : running_total_(identity), compensation_(identity), op_(op) {}
141+ : running_total_(identity), op_(op) {}
140142
141143 // Callback operator to be entered by the first warp of threads in the block.
142144 // tid 0 is responsible for returning a value for seeding the block-wide scan.
// Folds `block_aggregate` into running_total_ and returns the value the
// running total held before this call.
145+ __device__ T operator ()(T block_aggregate) {
146+ const T old_prefix = running_total_;
147+ running_total_ = op_ (running_total_, block_aggregate);
148+ return old_prefix;
149+ }
150+ };
151+
152+ template <typename T, typename Op>
153+ struct BlockPrefixCallbackOp <T, Op, true > {
154+ T running_total_;
155+ T compensation_;
156+ Op op_;
157+
158+ __device__ BlockPrefixCallbackOp (T identity, Op op)
159+ : running_total_(identity), compensation_(static_cast <T>(0.0 )), op_(op) {}
160+
143161 __device__ T operator ()(T block_aggregate) {
144162 T old_prefix = running_total_;
145163
@@ -155,20 +173,23 @@ struct BlockPrefixCallbackOp {
155173};
156174
157175template <typename T>
158- struct BlockPrefixCallbackOp <T, LogAddExp> {
176+ struct BlockPrefixCallbackOp <T, LogAddExp, true > {
159177 T max_so_far_;
160178 T scaled_sum_;
161179 T compensation_;
162180 LogAddExp op_;
163181
164182 __device__ BlockPrefixCallbackOp (T identity, LogAddExp op)
165- : max_so_far_(identity), scaled_sum_(0.0 ), compensation_(0.0 ), op_(op) {}
183+ : max_so_far_(identity),
184+ scaled_sum_(static_cast <T>(0.0 )),
185+ compensation_(static_cast <T>(0.0 )),
186+ op_(op) {}
166187
167188 __device__ T operator ()(T block_aggregate) {
168189 if (scaled_sum_ == 0.0 ) {
169190 max_so_far_ = block_aggregate;
170- scaled_sum_ = 1.0 ;
171- compensation_ = 0.0 ;
191+ scaled_sum_ = static_cast <T>( 1.0 ) ;
192+ compensation_ = static_cast <T>( 0.0 ) ;
172193 return std::numeric_limits<T>::lowest ();
173194 }
174195
@@ -195,15 +216,19 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
195216 }
196217};
197218
198- template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
219+ template <typename T,
220+ int BLOCK_THREADS,
221+ int ITEMS_PER_THREAD,
222+ typename Op,
223+ bool UseKahan>
199224__global__ void BlockScanKernel (T* d_out,
200225 const T* d_in,
201226 int64_t grid_size,
202227 int64_t scan_size,
203228 bool exclusive,
204229 Op op) {
205230 using MT = typename phi::dtype::MPTypeTrait<T>::Type;
206- using CallbackOp = BlockPrefixCallbackOp<MT, Op>;
231+ using CallbackOp = BlockPrefixCallbackOp<MT, Op, UseKahan >;
207232
208233 // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
209234 using BlockLoadT = cub::
@@ -350,14 +375,30 @@ void ScanKernel(const Context& dev_ctx,
350375 }
351376 }
352377
378+ // When scan_size is large, switch to Kahan scan to get better precision
379+ constexpr int64_t KAHAN_SWITCH_LENGTH = 1 << 16 ;
380+
353381 // Do scan
354382 if (!transpose && !reverse) {
355- BlockScanKernel<T, 128 , 4 , Op><<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
356- out_data, in_data, grid_size, scan_size, exclusive, op);
357-
383+ if (scan_size > KAHAN_SWITCH_LENGTH) {
384+ BlockScanKernel<T, 128 , 4 , Op, true >
385+ <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
386+ out_data, in_data, grid_size, scan_size, exclusive, op);
387+ } else {
388+ BlockScanKernel<T, 128 , 4 , Op, false >
389+ <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
390+ out_data, in_data, grid_size, scan_size, exclusive, op);
391+ }
358392 } else {
359- BlockScanKernel<T, 128 , 4 , Op><<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
360- next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
393+ if (scan_size > KAHAN_SWITCH_LENGTH) {
394+ BlockScanKernel<T, 128 , 4 , Op, true >
395+ <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
396+ next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
397+ } else {
398+ BlockScanKernel<T, 128 , 4 , Op, false >
399+ <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
400+ next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
401+ }
361402 }
362403 swap_ptr (next_in_data, next_out_data);
363404
0 commit comments