diff --git a/src/nnet3/nnet-general-component.cc b/src/nnet3/nnet-general-component.cc index 2720fbbd0bd..669e5112793 100644 --- a/src/nnet3/nnet-general-component.cc +++ b/src/nnet3/nnet-general-component.cc @@ -1563,9 +1563,8 @@ void* GeneralDropoutComponent::Propagate( int32 num_rows = out->NumRows(), dim_multiple = dim_ / block_dim_, num_rows_reshaped = num_rows * dim_multiple; - CuSubMatrix out_reshaped(out->Data(), block_dim_, - num_rows_reshaped, - num_rows_reshaped); + CuSubMatrix out_reshaped(out->Data(), num_rows_reshaped, + block_dim_, block_dim_); out_reshaped.MulRows(*mask, indexes->indexes); } else { out->MulRows(*mask, indexes->indexes); @@ -1602,9 +1601,9 @@ void GeneralDropoutComponent::Backprop( int32 num_rows = in_deriv->NumRows(), dim_multiple = dim_ / block_dim_, num_rows_reshaped = num_rows * dim_multiple; - CuSubMatrix in_deriv_reshaped(in_deriv->Data(), block_dim_, + CuSubMatrix in_deriv_reshaped(in_deriv->Data(), num_rows_reshaped, - num_rows_reshaped); + block_dim_, block_dim_); in_deriv_reshaped.MulRows(*mask, indexes->indexes); } else { in_deriv->MulRows(*mask, indexes->indexes);