Skip to content
48 changes: 48 additions & 0 deletions src/cudamatrix/cu-kernels-ansi.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,54 @@ void cudaD_add_mat_repeated(dim3 Gr, dim3 Bl, double alpha, const double *src,
MatrixDim src_dim, double *dst, MatrixDim dst_dim);
// Single-precision (float) entry point for the _add_mat_repeated kernel.
// Gr and Bl are the CUDA grid and block dimensions used for the launch; the
// remaining arguments are forwarded to the kernel unchanged.
void cudaF_add_mat_repeated(dim3 Gr, dim3 Bl, float alpha, const float *src,
MatrixDim src_dim, float *dst, MatrixDim dst_dim);
// Double-precision entry point for 3-D (t, h, f) max pooling.
//
//   Gr, Bl        CUDA grid / block dimensions of the launch; thread (x, y, z)
//                 handles pool (t, h, f).
//   src           input matrix (read only).
//   dst           receives the maximum of every pool.
//   index_max_    receives the row and column index of each pool's maximum.
//   stride        increment of the time-axis loop inside a pool.
//   input_*_dim_  input extent along t / h / f.
//   pool_*_size_  pooling window size along t / h / f.
//   pool_*_step_  pooling window step along t / h / f.
//   A_trans       non-zero selects the kernel that reads src as a transposed
//                 matrix; zero selects the plain kernel.
void cudaD_max_mat_blocks(dim3 Gr, dim3 Bl,
                          const double *src, double *dst, double *index_max_,
                          const int32_cuda stride,
                          const int32_cuda input_t_dim_,
                          const int32_cuda pool_t_size_,
                          const int32_cuda pool_t_step_,
                          const int32_cuda input_h_dim_,
                          const int32_cuda pool_h_size_,
                          const int32_cuda pool_h_step_,
                          const int32_cuda input_f_dim_,
                          const int32_cuda pool_f_size_,
                          const int32_cuda pool_f_step_,
                          int A_trans);
// Single-precision entry point for 3-D (t, h, f) max pooling.
//
//   Gr, Bl        CUDA grid / block dimensions of the launch; thread (x, y, z)
//                 handles pool (t, h, f).
//   src           input matrix (read only).
//   dst           receives the maximum of every pool.
//   index_max_    receives the row and column index of each pool's maximum.
//   stride        increment of the time-axis loop inside a pool.
//   input_*_dim_  input extent along t / h / f.
//   pool_*_size_  pooling window size along t / h / f.
//   pool_*_step_  pooling window step along t / h / f.
//   A_trans       non-zero selects the kernel that reads src as a transposed
//                 matrix; zero selects the plain kernel.
void cudaF_max_mat_blocks(dim3 Gr, dim3 Bl,
                          const float *src, float *dst, float *index_max_,
                          const int32_cuda stride,
                          const int32_cuda input_t_dim_,
                          const int32_cuda pool_t_size_,
                          const int32_cuda pool_t_step_,
                          const int32_cuda input_h_dim_,
                          const int32_cuda pool_h_size_,
                          const int32_cuda pool_h_step_,
                          const int32_cuda input_f_dim_,
                          const int32_cuda pool_f_size_,
                          const int32_cuda pool_f_step_,
                          int A_trans);
// Double-precision entry point for the backward pass of max pooling.
// Gr and Bl are the CUDA grid and block dimensions of the launch.
// src holds one value per pool, index_max_ holds the row/column index that the
// forward pass recorded for each pool, and dst is the matrix written by the
// kernel (one element per position of every pooling window).
// The input_*_dim_, pool_*_size_ and pool_*_step_ arguments describe the
// t / h / f geometry and are forwarded unchanged to the kernel.
void cudaD_max_mat_blocks_back(dim3 Gr, dim3 Bl,
const double *src, double *dst, double *index_max_,
const int32_cuda input_t_dim_,
const int32_cuda pool_t_size_,
const int32_cuda pool_t_step_,
const int32_cuda input_h_dim_,
const int32_cuda pool_h_size_,
const int32_cuda pool_h_step_,
const int32_cuda input_f_dim_,
const int32_cuda pool_f_size_,
const int32_cuda pool_f_step_);
// Single-precision entry point for the backward pass of max pooling.
// Gr and Bl are the CUDA grid and block dimensions of the launch.
// src holds one value per pool, index_max_ holds the row/column index that the
// forward pass recorded for each pool, and dst is the matrix written by the
// kernel (one element per position of every pooling window).
// The input_*_dim_, pool_*_size_ and pool_*_step_ arguments describe the
// t / h / f geometry and are forwarded unchanged to the kernel.
void cudaF_max_mat_blocks_back(dim3 Gr, dim3 Bl,
const float *src, float *dst, float *index_max_,
const int32_cuda input_t_dim_,
const int32_cuda pool_t_size_,
const int32_cuda pool_t_step_,
const int32_cuda input_h_dim_,
const int32_cuda pool_h_size_,
const int32_cuda pool_h_step_,
const int32_cuda input_f_dim_,
const int32_cuda pool_f_size_,
const int32_cuda pool_f_step_);
void cudaD_add_mat_diag_vec(dim3 Gr, dim3 Bl, double alpha, double *mat,
MatrixDim mat_dim, const double *mat2,
int mat2_row_stride, int mat2_col_stride,
Expand Down
228 changes: 228 additions & 0 deletions src/cudamatrix/cu-kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,149 @@ static void _add_mat_blocks_trans(Real alpha, const Real* src,
}
}

// Forward max-pooling kernel over a (t, h, f) input stored row-major in `src`
// with input_h_dim_ * input_f_dim_ elements per row.
//
// Launch layout: one thread per pool; the x / y / z thread coordinates are the
// pool indices along t / h / f.  Threads whose coordinates fall outside the
// pool grid return without touching memory.
//
//   src           input matrix (read only).
//   dst           one value per pool: the maximum found in that pool.
//   index_max_    row / column index of the maximum of each pool.
//   stride        increment of the time-axis loop inside a pool.
//   input_*_dim_  input extent along t / h / f.
//   pool_*_size_  pooling window size along t / h / f.
//   pool_*_step_  pooling window step along t / h / f.
template<typename Real>
__global__
static void _max_mat_blocks(const Real *src, Real *dst, Real *index_max_,
                            const int32_cuda stride,
                            const int32_cuda input_t_dim_,
                            const int32_cuda pool_t_size_,
                            const int32_cuda pool_t_step_,
                            const int32_cuda input_h_dim_,
                            const int32_cuda pool_h_size_,
                            const int32_cuda pool_h_step_,
                            const int32_cuda input_f_dim_,
                            const int32_cuda pool_f_size_,
                            const int32_cuda pool_f_step_) {
  int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
  int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y;
  int32_cuda k = blockIdx.z * blockDim.z + threadIdx.z;
  int32_cuda num_pools_t = 1 + (input_t_dim_ - pool_t_size_) / pool_t_step_;
  int32_cuda num_pools_h = 1 + (input_h_dim_ - pool_h_size_) / pool_h_step_;
  int32_cuda num_pools_f = 1 + (input_f_dim_ - pool_f_size_) / pool_f_step_;

  // The grid may be larger than the pool grid (block sizes rarely divide it
  // evenly); surplus threads must not read or write anything.
  if (i >= num_pools_t || j >= num_pools_h || k >= num_pools_f)
    return;

  // number of elements in one row of *src
  const int32_cuda row_width = input_h_dim_ * input_f_dim_;

  // initialize the temporary maximum value and its index in each pool.
  // max_value must have the element type: an integer here would truncate
  // every sample before it is compared and stored.
  int32_cuda max_row = i * pool_t_step_;
  int32_cuda max_col = j * pool_h_step_ * input_f_dim_ + k * pool_f_step_;
  Real max_value = src[max_row * row_width + max_col];

  // loop over all the elements in each pool to find the maximum one,
  // and record its index.
  for (int32_cuda t = 0; t < pool_t_size_; t += stride) {
    // the index of row in *src
    int32_cuda idx_row = i * pool_t_step_ + t;

    for (int32_cuda h = 0; h < pool_h_size_; h++) {
      for (int32_cuda f = 0; f < pool_f_size_; f++) {
        // the index of column in *src
        int32_cuda idx_col = (j * pool_h_step_ + h) * input_f_dim_
                             + k * pool_f_step_ + f;

        const Real value = src[idx_row * row_width + idx_col];
        if (value > max_value) {
          max_row = idx_row;
          max_col = idx_col;
          max_value = value;
        }
      }
    }
  }

  dst[i * num_pools_h * num_pools_f + j * num_pools_f + k] = max_value;

  // the index of indexes stored in vector 'index_max_'.
  // NOTE(review): consecutive pools use consecutive values of idx_in_idxmax,
  // so the slot idx_in_idxmax + 1 written here is also the max_row slot of the
  // next pool.  A 2 * idx / 2 * idx + 1 layout would avoid the overlap, but it
  // has to be changed together with _max_mat_blocks_back and the size of the
  // buffer allocated by the caller -- confirm before changing.
  int32_cuda idx_in_idxmax = (i * num_pools_h + j) * num_pools_f + k;
  index_max_[idx_in_idxmax] = max_row;
  index_max_[idx_in_idxmax + 1] = max_col;
}

// This function is basically the same as _max_mat_blocks, except that it
// deals with the transpose of *src, so the column and row indexes are
// exchanged: element (row, col) is read from src[col * input_t_dim_ + row]
// and the result for pool (i, j, k) is written to
// dst[(j * num_pools_f + k) * num_pools_t + i].
//
// Launch layout: one thread per pool; the x / y / z thread coordinates are the
// pool indices along t / h / f.  Threads whose coordinates fall outside the
// pool grid return without touching memory.
//
//   src           transposed input matrix (read only).
//   dst           one value per pool: the maximum found in that pool.
//   index_max_    row / column index of the maximum of each pool.
//   stride        increment of the time-axis loop inside a pool.
//   input_*_dim_  input extent along t / h / f.
//   pool_*_size_  pooling window size along t / h / f.
//   pool_*_step_  pooling window step along t / h / f.
template<typename Real>
__global__
static void _max_mat_blocks_trans(const Real *src, Real *dst, Real *index_max_,
                                  const int32_cuda stride,
                                  const int32_cuda input_t_dim_,
                                  const int32_cuda pool_t_size_,
                                  const int32_cuda pool_t_step_,
                                  const int32_cuda input_h_dim_,
                                  const int32_cuda pool_h_size_,
                                  const int32_cuda pool_h_step_,
                                  const int32_cuda input_f_dim_,
                                  const int32_cuda pool_f_size_,
                                  const int32_cuda pool_f_step_) {
  int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
  int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y;
  int32_cuda k = blockIdx.z * blockDim.z + threadIdx.z;
  int32_cuda num_pools_t = 1 + (input_t_dim_ - pool_t_size_) / pool_t_step_;
  int32_cuda num_pools_h = 1 + (input_h_dim_ - pool_h_size_) / pool_h_step_;
  int32_cuda num_pools_f = 1 + (input_f_dim_ - pool_f_size_) / pool_f_step_;

  // The grid may be larger than the pool grid; surplus threads must not read
  // or write anything.
  if (i >= num_pools_t || j >= num_pools_h || k >= num_pools_f)
    return;

  // Start from the first element of the pool.  max_value must have the
  // element type: an integer here would truncate every sample before it is
  // compared and stored.
  int32_cuda max_row = i * pool_t_step_;
  int32_cuda max_col = j * pool_h_step_ * input_f_dim_ + k * pool_f_step_;
  Real max_value = src[max_col * input_t_dim_ + max_row];

  for (int32_cuda t = 0; t < pool_t_size_; t += stride) {
    int32_cuda idx_row = i * pool_t_step_ + t;

    for (int32_cuda h = 0; h < pool_h_size_; h++) {
      for (int32_cuda f = 0; f < pool_f_size_; f++) {
        int32_cuda idx_col = (j * pool_h_step_ + h) * input_f_dim_
                             + k * pool_f_step_ + f;

        const Real value = src[idx_col * input_t_dim_ + idx_row];
        if (value > max_value) {
          max_row = idx_row;
          max_col = idx_col;
          max_value = value;
        }
      }
    }
  }

  dst[(j * num_pools_f + k) * num_pools_t + i] = max_value;

  // NOTE(review): consecutive pools use consecutive values of idx_in_idxmax,
  // so the slot idx_in_idxmax + 1 written here is also the max_row slot of the
  // next pool.  Changing the layout requires matching changes in
  // _max_mat_blocks_back and in the caller's allocation -- confirm first.
  int32_cuda idx_in_idxmax = (i * num_pools_h + j) * num_pools_f + k;
  index_max_[idx_in_idxmax] = max_row;
  index_max_[idx_in_idxmax + 1] = max_col;
}

// Backward pass of max pooling.
//
// Launch layout: one thread per pool; the x / y / z thread coordinates are the
// pool indices along t / h / f.  Threads whose coordinates fall outside the
// pool grid return without touching memory.
//
// Each thread walks every element of its pooling window in `dst` (row-major,
// input_h_dim_ * input_f_dim_ elements per row).  An element receives the
// pool's value from `src` when its (row, column) equals the pair stored in
// index_max_ for this pool, or when dst already holds a non-zero value at that
// position; otherwise it is set to zero.
//
//   src           one value per pool (read only).
//   dst           matrix written by the kernel; it is also read.
//   index_max_    row / column index recorded for each pool (read only here).
//   input_*_dim_  input extent along t / h / f.
//   pool_*_size_  pooling window size along t / h / f.
//   pool_*_step_  pooling window step along t / h / f.
template<typename Real>
__global__
static void _max_mat_blocks_back(const Real *src, Real *dst, Real *index_max_,
                                 const int32_cuda input_t_dim_,
                                 const int32_cuda pool_t_size_,
                                 const int32_cuda pool_t_step_,
                                 const int32_cuda input_h_dim_,
                                 const int32_cuda pool_h_size_,
                                 const int32_cuda pool_h_step_,
                                 const int32_cuda input_f_dim_,
                                 const int32_cuda pool_f_size_,
                                 const int32_cuda pool_f_step_) {
  int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
  int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y;
  int32_cuda k = blockIdx.z * blockDim.z + threadIdx.z;
  int32_cuda num_pools_t = 1 + (input_t_dim_ - pool_t_size_) / pool_t_step_;
  int32_cuda num_pools_h = 1 + (input_h_dim_ - pool_h_size_) / pool_h_step_;
  int32_cuda num_pools_f = 1 + (input_f_dim_ - pool_f_size_) / pool_f_step_;

  // The grid may be larger than the pool grid; surplus threads must not read
  // or write anything.
  if (i >= num_pools_t || j >= num_pools_h || k >= num_pools_f)
    return;

  // Everything below depends only on (i, j, k), so it is computed once
  // instead of once per window element.
  const int32_cuda idx_in_idxmax = (i * num_pools_h + j) * num_pools_f + k;
  const int32_cuda row_width = input_h_dim_ * input_f_dim_;
  const Real pool_value = src[i * num_pools_h * num_pools_f
                              + j * num_pools_f + k];
  const Real argmax_row = index_max_[idx_in_idxmax];
  const Real argmax_col = index_max_[idx_in_idxmax + 1];

  for (int32_cuda t = 0; t < pool_t_size_; t++) {
    int32_cuda idx_row = i * pool_t_step_ + t;

    for (int32_cuda h = 0; h < pool_h_size_; h++) {
      for (int32_cuda f = 0; f < pool_f_size_; f++) {
        int32_cuda idx_col = (j * pool_h_step_ + h) * input_f_dim_
                             + k * pool_f_step_ + f;
        const int32_cuda dst_idx = idx_row * row_width + idx_col;

        // NOTE(review): dst is read before it is written.  When pooling
        // windows overlap, another thread may be writing the same element, so
        // the outcome depends on scheduling; the test also relies on the
        // contents dst had before the launch -- confirm the caller's contract.
        if ((idx_row == argmax_row && idx_col == argmax_col) ||
            dst[dst_idx] != 0) {
          dst[dst_idx] = pool_value;
        } else {
          dst[dst_idx] = 0;
        }
      }
    }
  }
}

template<typename Real>
__global__
static void _set_mat_mat_div_mat(const Real* A, const Real* B, const Real* C,
Expand Down Expand Up @@ -3957,6 +4100,48 @@ void cudaF_add_mat_repeated(dim3 Gr, dim3 Bl, float alpha, const float* src,
_add_mat_repeated<<<Gr,Bl>>>(alpha, src, src_dim, dst, dst_dim);
}

// Float launcher for forward max pooling.  Picks the plain or the transposed
// kernel from A_trans and launches it on grid Gr with block Bl, forwarding
// every other argument unchanged.
void cudaF_max_mat_blocks(dim3 Gr, dim3 Bl,
                          const float *src, float *dst, float *index_max_,
                          const int32_cuda stride,
                          const int32_cuda input_t_dim_,
                          const int32_cuda pool_t_size_,
                          const int32_cuda pool_t_step_,
                          const int32_cuda input_h_dim_,
                          const int32_cuda pool_h_size_,
                          const int32_cuda pool_h_step_,
                          const int32_cuda input_f_dim_,
                          const int32_cuda pool_f_size_,
                          const int32_cuda pool_f_step_,
                          int A_trans) {
  // A_trans == 0: src is read in its stored orientation.
  if (!A_trans) {
    _max_mat_blocks<<<Gr,Bl>>>(
        src, dst, index_max_, stride,
        input_t_dim_, pool_t_size_, pool_t_step_,
        input_h_dim_, pool_h_size_, pool_h_step_,
        input_f_dim_, pool_f_size_, pool_f_step_);
    return;
  }

  // Any non-zero A_trans: src is read as a transposed matrix.
  _max_mat_blocks_trans<<<Gr,Bl>>>(
      src, dst, index_max_, stride,
      input_t_dim_, pool_t_size_, pool_t_step_,
      input_h_dim_, pool_h_size_, pool_h_step_,
      input_f_dim_, pool_f_size_, pool_f_step_);
}

// Float launcher for the backward pass of max pooling: starts
// _max_mat_blocks_back on grid Gr with block Bl and forwards every other
// argument unchanged.
void cudaF_max_mat_blocks_back(dim3 Gr, dim3 Bl,
                               const float *src, float *dst, float *index_max_,
                               const int32_cuda input_t_dim_,
                               const int32_cuda pool_t_size_,
                               const int32_cuda pool_t_step_,
                               const int32_cuda input_h_dim_,
                               const int32_cuda pool_h_size_,
                               const int32_cuda pool_h_step_,
                               const int32_cuda input_f_dim_,
                               const int32_cuda pool_f_size_,
                               const int32_cuda pool_f_step_) {
  _max_mat_blocks_back<<<Gr,Bl>>>(
      src,
      dst,
      index_max_,
      input_t_dim_, pool_t_size_, pool_t_step_,   // time axis
      input_h_dim_, pool_h_size_, pool_h_step_,   // h axis
      input_f_dim_, pool_f_size_, pool_f_step_);  // f axis
}

void cudaF_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
const float *C, float *dst, MatrixDim d,
Expand Down Expand Up @@ -4661,6 +4846,49 @@ void cudaD_add_mat_repeated(dim3 Gr, dim3 Bl, double alpha, const double* src,
_add_mat_repeated<<<Gr,Bl>>>(alpha, src, src_dim, dst, dst_dim);
}

// Double launcher for forward max pooling.  Picks the plain or the transposed
// kernel from A_trans and launches it on grid Gr with block Bl, forwarding
// every other argument unchanged.
void cudaD_max_mat_blocks(dim3 Gr, dim3 Bl,
                          const double *src, double *dst, double *index_max_,
                          const int32_cuda stride,
                          const int32_cuda input_t_dim_,
                          const int32_cuda pool_t_size_,
                          const int32_cuda pool_t_step_,
                          const int32_cuda input_h_dim_,
                          const int32_cuda pool_h_size_,
                          const int32_cuda pool_h_step_,
                          const int32_cuda input_f_dim_,
                          const int32_cuda pool_f_size_,
                          const int32_cuda pool_f_step_,
                          int A_trans) {
  // A_trans == 0: src is read in its stored orientation.
  if (!A_trans) {
    _max_mat_blocks<<<Gr,Bl>>>(
        src, dst, index_max_, stride,
        input_t_dim_, pool_t_size_, pool_t_step_,
        input_h_dim_, pool_h_size_, pool_h_step_,
        input_f_dim_, pool_f_size_, pool_f_step_);
    return;
  }

  // Any non-zero A_trans: src is read as a transposed matrix.
  _max_mat_blocks_trans<<<Gr,Bl>>>(
      src, dst, index_max_, stride,
      input_t_dim_, pool_t_size_, pool_t_step_,
      input_h_dim_, pool_h_size_, pool_h_step_,
      input_f_dim_, pool_f_size_, pool_f_step_);
}

// Double launcher for the backward pass of max pooling: starts
// _max_mat_blocks_back on grid Gr with block Bl and forwards every other
// argument unchanged.
void cudaD_max_mat_blocks_back(dim3 Gr, dim3 Bl,
                               const double *src, double *dst, double *index_max_,
                               const int32_cuda input_t_dim_,
                               const int32_cuda pool_t_size_,
                               const int32_cuda pool_t_step_,
                               const int32_cuda input_h_dim_,
                               const int32_cuda pool_h_size_,
                               const int32_cuda pool_h_step_,
                               const int32_cuda input_f_dim_,
                               const int32_cuda pool_f_size_,
                               const int32_cuda pool_f_step_) {
  _max_mat_blocks_back<<<Gr,Bl>>>(
      src,
      dst,
      index_max_,
      input_t_dim_, pool_t_size_, pool_t_step_,   // time axis
      input_h_dim_, pool_h_size_, pool_h_step_,   // h axis
      input_f_dim_, pool_f_size_, pool_f_step_);  // f axis
}

void cudaD_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
const double *B, const double *C, double *dst,
MatrixDim d, int stride_a, int stride_b,
Expand Down
68 changes: 68 additions & 0 deletions src/cudamatrix/cu-kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,74 @@ inline void cuda_add_mat_repeated(dim3 Gr, dim3 Bl, float alpha,
float *dst, MatrixDim dst_dim) {
cudaF_add_mat_repeated(Gr, Bl, alpha, src, src_dim, dst, dst_dim);
}
// Overload for double: forwards every argument, in order, to
// cudaD_max_mat_blocks.
inline void cuda_max_mat_blocks(dim3 Gr, dim3 Bl,
                                const double *src, double *dst, double *index_max_,
                                const int32_cuda stride,
                                const int32_cuda input_t_dim_,
                                const int32_cuda pool_t_size_,
                                const int32_cuda pool_t_step_,
                                const int32_cuda input_h_dim_,
                                const int32_cuda pool_h_size_,
                                const int32_cuda pool_h_step_,
                                const int32_cuda input_f_dim_,
                                const int32_cuda pool_f_size_,
                                const int32_cuda pool_f_step_,
                                int A_trans) {
  cudaD_max_mat_blocks(
      Gr, Bl,
      src, dst, index_max_,
      stride,
      input_t_dim_, pool_t_size_, pool_t_step_,   // time axis
      input_h_dim_, pool_h_size_, pool_h_step_,   // h axis
      input_f_dim_, pool_f_size_, pool_f_step_,   // f axis
      A_trans);
}
// Overload for float: forwards every argument, in order, to
// cudaF_max_mat_blocks.
inline void cuda_max_mat_blocks(dim3 Gr, dim3 Bl,
                                const float *src, float *dst, float *index_max_,
                                const int32_cuda stride,
                                const int32_cuda input_t_dim_,
                                const int32_cuda pool_t_size_,
                                const int32_cuda pool_t_step_,
                                const int32_cuda input_h_dim_,
                                const int32_cuda pool_h_size_,
                                const int32_cuda pool_h_step_,
                                const int32_cuda input_f_dim_,
                                const int32_cuda pool_f_size_,
                                const int32_cuda pool_f_step_,
                                int A_trans) {
  cudaF_max_mat_blocks(
      Gr, Bl,
      src, dst, index_max_,
      stride,
      input_t_dim_, pool_t_size_, pool_t_step_,   // time axis
      input_h_dim_, pool_h_size_, pool_h_step_,   // h axis
      input_f_dim_, pool_f_size_, pool_f_step_,   // f axis
      A_trans);
}
// Overload for double: forwards every argument, in order, to
// cudaD_max_mat_blocks_back.
inline void cuda_max_mat_blocks_back(dim3 Gr, dim3 Bl,
                                     const double *src, double *dst, double *index_max_,
                                     const int32_cuda input_t_dim_,
                                     const int32_cuda pool_t_size_,
                                     const int32_cuda pool_t_step_,
                                     const int32_cuda input_h_dim_,
                                     const int32_cuda pool_h_size_,
                                     const int32_cuda pool_h_step_,
                                     const int32_cuda input_f_dim_,
                                     const int32_cuda pool_f_size_,
                                     const int32_cuda pool_f_step_) {
  cudaD_max_mat_blocks_back(
      Gr, Bl,
      src, dst, index_max_,
      input_t_dim_, pool_t_size_, pool_t_step_,   // time axis
      input_h_dim_, pool_h_size_, pool_h_step_,   // h axis
      input_f_dim_, pool_f_size_, pool_f_step_);  // f axis
}
// Overload for float: forwards every argument, in order, to
// cudaF_max_mat_blocks_back.
inline void cuda_max_mat_blocks_back(dim3 Gr, dim3 Bl,
                                     const float *src, float *dst, float *index_max_,
                                     const int32_cuda input_t_dim_,
                                     const int32_cuda pool_t_size_,
                                     const int32_cuda pool_t_step_,
                                     const int32_cuda input_h_dim_,
                                     const int32_cuda pool_h_size_,
                                     const int32_cuda pool_h_step_,
                                     const int32_cuda input_f_dim_,
                                     const int32_cuda pool_f_size_,
                                     const int32_cuda pool_f_step_) {
  cudaF_max_mat_blocks_back(
      Gr, Bl,
      src, dst, index_max_,
      input_t_dim_, pool_t_size_, pool_t_step_,   // time axis
      input_h_dim_, pool_h_size_, pool_h_step_,   // h axis
      input_f_dim_, pool_f_size_, pool_f_step_);  // f axis
}
inline void cuda_add_mat_diag_vec(dim3 Gr, dim3 Bl, double alpha, double *mat,
MatrixDim mat_dim, const double *mat2,
int mat2_row_stride, int mat2_col_stride,
Expand Down
Loading