Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/gather_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ __global__ void _GatherKernel(
T* output_data,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_index = 0;
int input_block_index, block_offset;
output_block_size.divmod(id, input_block_index, block_offset);
int indices_index, offset;
Expand All @@ -47,7 +46,10 @@ __global__ void _GatherKernel(
return;
}

input_index = input_block_index * input_block_size + idx * block_size.d_ + offset;
// Use int64_t to avoid overflow when the input tensor has more than
// INT32_MAX elements (e.g. a [262144, 8960] embedding table = 2.35B).
int64_t input_index = static_cast<int64_t>(input_block_index) * input_block_size +
idx * block_size.d_ + offset;
output_data[id] = input_data[input_index];
}

Expand Down
Loading