
Conversation

@zmxdream (Contributor)

Hi, I took a look at the code of column_sum_reduce and I have two questions:

  1. The goal of column_sum_reduce is to compute the column sums of the inp matrix with shape [rows, width], so the result shape should be [width], right? It seems the bounds check on pos is not correct.
  2. The CUDA kernel's implementation is based on the assumption that threads with the same threadIdx.y are grouped into a thread_block_tile, with blockDim = (32, 32). I read the NVIDIA slides https://on-demand.gputechconf.com/gtc/2017/presentation/s7622-Kyrylo-perelygin-robust-and-scalable-cuda.pdf, which say a THREAD BLOCK TILE is a subset of the threads of a thread block, divided into tiles in row-major order. Doesn't that mean threads with the same threadIdx.x are grouped into a thread_block_tile?
    Thanks!!!!

```diff
 if (threadIdx.x == 0) {
     int pos = blockIdx.x * TILE_DIM + threadIdx.y;
-    if (pos < (rows * width)) out[pos] = sum;
+    if (pos < width) out[pos] = sum;
```
@RezaYazdaniAminabadi (Contributor) — Feb 28, 2021

Thanks for fixing this! I would say it still worked when the hidden dimension was divisible by 32; however, it would have caused an out-of-bounds write when the hidden dimension is not divisible by 32!

@zmxdream (Contributor, Author)

Yes! Thanks for your approval!!

@RezaYazdaniAminabadi (Contributor)


Hi @zmx19951103

Thanks for fixing this bug. Regarding your second question: both the x and y dimensions are assigned to different thread_block tiles. However, since this is a 2-dimensional tile, we just use threadIdx.y for saving the output after everything is reduced across each tile, i.e., over the threads that share the same y index while the x index ranges from 0 to 31. So what you are saying is also true, and this is also our assumption when reducing the elements in a row!
I hope this answers your question.
Thanks,
Reza

@RezaYazdaniAminabadi RezaYazdaniAminabadi merged commit 937c5ce into deepspeedai:master Feb 28, 2021
