12 changes: 10 additions & 2 deletions src/cudamatrix/cu-kernels-ansi.h
@@ -5,6 +5,7 @@
// 2013 Hainan Xu
// 2013 Xiaohui Zhang
// 2013-2015 Guoguo Chen
// 2016 David Snyder

// See ../../COPYING for clarification regarding multiple authors
//
@@ -179,6 +180,11 @@ void cudaF_equal_element_mask(dim3 Gr, dim3 Bl, const float *mat1,
const float *mat2, float *mask, MatrixDim mat1_dim,
int mat2_stride, int mask_stride);

void cudaF_compute_xvector_objf(dim3 Gr, dim3 Bl, const float *scores,
MatrixDim scores_dim, float *objf_terms,
MatrixDim objf_dim, float *objf_derivs,
MatrixDim derivs_dim);

/*********************************************************
* double CUDA kernel calls
*/
@@ -302,6 +308,10 @@ void cudaD_copy_from_sp(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim
void cudaD_take_lower(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim d_in);
void cudaD_take_upper(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim d_in);
void cudaD_take_mean(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim d_in);
void cudaD_compute_xvector_objf(dim3 Gr, dim3 Bl, const double *scores,
MatrixDim scores_dim, double *objf_terms,
MatrixDim objf_dim, double *objf_derivs,
MatrixDim derivs_dim);


// some mostly mixed-type kernels.
@@ -349,8 +359,6 @@ void cudaD_equal_element_mask(dim3 Gr, dim3 Bl, const double *mat1,
const double *mat2, double *mask, MatrixDim mat1_dim,
int mat2_stride, int mask_stride);



} // extern "C"

#endif // HAVE_CUDA
36 changes: 36 additions & 0 deletions src/cudamatrix/cu-kernels.cu
@@ -6,6 +6,7 @@
// 2013 Hainan Xu
// 2013 Xiaohui Zhang
// 2013-2015 Guoguo Chen
// 2016 David Snyder

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -2094,6 +2095,26 @@ static void _diff_xent(const int32_cuda* vec_tgt, Real* mat_net_out, Real* vec_l
}
}

template<typename Real>
__global__
static void _compute_xvector_objf(const Real* scores, MatrixDim scores_dim,
Real* objf_terms, MatrixDim objf_dim,
Real* objf_derivs, MatrixDim derivs_dim) {
int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y;
int32_cuda scores_index = i + j * scores_dim.stride;
Real K = 1.0 / (scores_dim.rows - 2.0);
Real L = scores[scores_index];
if (i < scores_dim.cols && j < scores_dim.rows && i < j) {
Owner:
To avoid separately having to zero the upper triangle and the diagonal of the matrix, you might as well do it in this kernel. [i.e. and set it to kUndefined before calling this kernel].
However, I suppose this all becomes moot if you end up using Pegah's idea and rely on the SoftHinge kernel and a fixed scaling matrix.
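A minimal sketch of that suggestion against the kernel above (the else branch is an addition for illustration, not part of the PR): threads that land on the diagonal or the unused triangle write zeros, so the output matrices could be allocated with kUndefined.

if (i < scores_dim.cols && j < scores_dim.rows) {
  if (i < j) {
    // ... the existing same-class / different-class assignments go here ...
  } else {
    // Diagonal and the unused triangle: write zeros so the caller does not
    // have to zero objf_terms and objf_derivs separately.
    objf_terms[scores_index] = 0.0;
    objf_derivs[scores_index] = 0.0;
  }
}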

Author:
After looking at it more, I think it's better to just do this in a cuda kernel.

Also, I still need to make kernels for the actual derivatives, which are somewhat nontrivial to compute in an efficient way... I don't think it's possible to use Pegah's idea to handle them.

Owner:
I think the only not-100%-trivial thing about the derivatives is the fact that different parts of the matrix have different scaling factors. You could probably compute the objf and derivs as follows using individual kernels.

  • get matrix of scores.
  • apply fixed-scaling-1 to matrix of scores (to negate different-class)
  • compute soft-hinge function
  • Compute TraceMatMat of this matrix with a fixed scaling matrix fixed-scaling-2 (with 1/(num-rows-2) for different-class members) to get the objf
  • use the Sigmoid function to compute the derivative of the soft-hinge nonlinearity
  • Multiply the derivatives by fixed-scaling-1 * fixed-scaling-2. These are the derivatives of the objective function w.r.t. the raw scores.

There may be a few signs wrong here. However, it would be more efficient to do all of the above in a single kernel. You can easily do it in the same kernel that computes the objective-function terms. [do summation via matrix-sum though].

Dan
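A rough sketch of that per-kernel recipe using existing CuMatrix operations (scores, fixed_scaling_1, fixed_scaling_2 and N are placeholder names, not code from this PR, and the signs carry the same caveat as above):

// Sketch only: objf and d(objf)/d(scores) from an N x N matrix of raw scores.
CuMatrix<BaseFloat> scaled(scores);                  // copy of the raw scores
scaled.MulElements(fixed_scaling_1);                 // negate the different-class entries
CuMatrix<BaseFloat> hinge(N, N);
hinge.SoftHinge(scaled);                             // elementwise log(1 + exp(x))
BaseFloat objf = TraceMatMat(hinge, fixed_scaling_2, kTrans);  // weighted sum of the terms
CuMatrix<BaseFloat> scores_deriv(N, N);
scores_deriv.Sigmoid(scaled);                        // derivative of the soft-hinge
scores_deriv.MulElements(fixed_scaling_1);           // chain rule through the first scaling
scores_deriv.MulElements(fixed_scaling_2);           // per-entry weights from the objf
// scores_deriv now holds d(objf)/d(scores).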


Author:
I think you're describing an alternative way to get the coefficients for the derivative terms. But, the kernel code above already does that.

On the CPU, the derivative w.r.t. S needs something like the following (NOTE: I'm ignoring peculiarities due to S being symmetric):

for i=0 ... N:
  for j = 0 ... N:
     v = xvectors(i)
     w = xvectors(j)
     deriv_S += C(i,j) * (v v' + w w')

Where C() is a coefficient dependent on whether or not the vectors at row i and j are from the same or different classes. This is what we calculated in the kernel above.

Each v,w pair results in its own matrix. I think this makes it harder to deal with in a single kernel. I think the easiest thing to do is to create an additional kernel that works like a modified form of matrix multiplication. Suppose V is the matrix of xvectors and D = NumCols(V). Then P = V' "times" V is the serialized outer product of each row of V. For example, P.Row(0) = Serialized( V.Row(0) * V.Row(0)'). In other words, p_{i,j} = v_{i, (j / D) % D} * v_{i, j % D}.

Once that is done, it should be more straightforward to calculate S_deriv += C(i, j) * (P.Row(i) + P.Row(j)) in parallel.
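A rough sketch of that proposed kernel, following the indexing conventions of the kernels in this PR (the function name and argument layout are assumptions):

// Each row of P is the serialized outer product of the same row of V, i.e.
// P(i, j) = V(i, j / D) * V(i, j % D), where D = NumCols(V) and P has D*D columns.
template<typename Real>
__global__
static void _serialized_outer_products(const Real* V, MatrixDim v_dim,
                                       Real* P, MatrixDim p_dim) {
  int32_cuda j = blockIdx.x * blockDim.x + threadIdx.x;  // column index into P
  int32_cuda i = blockIdx.y * blockDim.y + threadIdx.y;  // row index into P and V
  int32_cuda D = v_dim.cols;
  if (j < p_dim.cols && i < p_dim.rows) {
    P[i * p_dim.stride + j] =
        V[i * v_dim.stride + j / D] * V[i * v_dim.stride + j % D];
  }
}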

Owner:
I don't think you are really thinking about this in the spirit of backprop. The general principle is that you go forward computing the objective function, and then you do a process that is roughly the mirror-image of the forward process to backprop the derivatives through the computation.

What I described was getting the derivatives of the objective function w.r.t. the matrix of scores. After that you just have to do the reverse of the forward operations to get the derivatives w.r.t. S and the matrix of xvectors.

Dan


Author:
After that you just have to do the reverse of the forward operations to get the derivatives w.r.t. S and the matrix of xvectors.

Right, that's what I'm referring to. Once you have the derivs of the objf w.r.t. the scores (included in C(i,j)), you still need to compute the derivative of the scores w.r.t. S. However, as far as I can tell, unless you try to do that in a kernel, you'll end up with an algorithm with two loops over the xvectors (see pseudo-code in earlier post). I proposed the kernel above to parallelize that computation.

Owner:
OK, let me work this out...
The forward computation is something like:

A = X X'
cvec = diag(X S X')
u = vector of ones
scores = A - cvec u' - u cvec' + b
... compute the objf and get scores_deriv, which is d(objf)/d(scores)
A_deriv = scores_deriv
X_deriv += 2 A_deriv X (or something like that)
cvec_deriv = - sum of scores_deriv columns - sum of scores_deriv rows

When computing the deriv w.r.t. S I am thinking about the expression cvec_deriv . cvec, which equals trace(diag(cvec_deriv) X S X'), where diag(cvec_deriv) is a matrix whose diagonal is cvec_deriv, which we can rearrange to trace(S (X' diag(cvec_deriv) X)). We get from this (through a mysterious process, I do it intuitively),

S_deriv = X' diag(cvec_deriv) X

which is pretty easy to compute.
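A short sketch of those backprop steps with existing CuMatrix/CuVector calls (X, scores_deriv, N and D are placeholder names for the N x D xvector matrix and the N x N derivative of the objf w.r.t. the scores; the extra contribution to X_deriv through cvec is omitted):

// cvec_deriv = -(sum of scores_deriv over rows) - (sum over columns)
CuVector<BaseFloat> cvec_deriv(N);
cvec_deriv.AddRowSumMat(-1.0, scores_deriv, 0.0);
cvec_deriv.AddColSumMat(-1.0, scores_deriv, 1.0);
// A_deriv = scores_deriv, and X_deriv += 2 A_deriv X (as above).
CuMatrix<BaseFloat> X_deriv(N, D);
X_deriv.AddMatMat(2.0, scores_deriv, kNoTrans, X, kNoTrans, 0.0);
// S_deriv = X' diag(cvec_deriv) X
CuMatrix<BaseFloat> tmp(N, D);
tmp.AddDiagVecMat(1.0, cvec_deriv, X, kNoTrans, 0.0);
CuMatrix<BaseFloat> S_deriv(D, D);
S_deriv.AddMatMat(1.0, X, kTrans, tmp, kNoTrans, 0.0);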


Author:
OK, I'll play with it some more to see if I can get it to work without a kernel and without an O(N^2) computation.

In your procedure, it isn't obvious to me (yet) that you can get terms of the form S_deriv = C(x,y) * (x x' + y y') for all combinations of (x,y) pairs. That's where the O(N^2) comes from that I'm trying to avoid.

Owner:
The fact that it was possible in the forward computation generally means it's possible in the backward computation. You'll get S_deriv = X' diag(cvec_deriv) X, I think.


if (i + 1 == j && i % 2 == 0) {
objf_terms[scores_index] = log(1.0 + exp(-L));
objf_derivs[scores_index] = 1.0 / (1.0 + exp(L));
} else if (i != j) {
objf_terms[scores_index] = K * log(1.0 + exp(L));
objf_derivs[scores_index] = -K / (1.0 + exp(-L));
}
}
}


/***********************************************************************
@@ -2575,6 +2596,14 @@ void cudaF_equal_element_mask(dim3 Gr, dim3 Bl, const float *mat1,
_equal_element_mask<<<Gr,Bl>>>(mat1, mat2, mask, mat1_dim, mat2_stride, mask_stride);
}

void cudaF_compute_xvector_objf(dim3 Gr, dim3 Bl, const float *scores,
MatrixDim scores_dim, float *objf_terms,
MatrixDim objf_dim, float *objf_derivs,
MatrixDim derivs_dim) {
_compute_xvector_objf<<<Gr,Bl>>>(scores, scores_dim, objf_terms, objf_dim,
objf_derivs, derivs_dim);
}

/*
* "double"
*/
@@ -3029,6 +3058,13 @@ void cudaD_equal_element_mask(dim3 Gr, dim3 Bl, const double *mat1,
_equal_element_mask<<<Gr,Bl>>>(mat1, mat2, mask, mat1_dim, mat2_stride, mask_stride);
}

void cudaD_compute_xvector_objf(dim3 Gr, dim3 Bl, const double *scores,
MatrixDim scores_dim, double *objf_terms,
MatrixDim objf_dim, double *objf_derivs,
MatrixDim derivs_dim) {
_compute_xvector_objf<<<Gr,Bl>>>(scores, scores_dim, objf_terms, objf_dim,
objf_derivs, derivs_dim);
}


/* Some conversion kernels for which it's more convenient to not name them F or D. */
15 changes: 15 additions & 0 deletions src/cudamatrix/cu-kernels.h
@@ -289,6 +289,13 @@ inline void cuda_equal_element_mask(dim3 Gr, dim3 Bl, const float *mat1, const f
cudaF_equal_element_mask(Gr, Bl, mat1, mat2, mask, mat1_dim, mat2_stride, mask_stride);
}

inline void cuda_compute_xvector_objf(dim3 Gr, dim3 Bl, const float *scores,
MatrixDim scores_dim, float *objf_terms,
MatrixDim objf_dim, float *objf_derivs,
MatrixDim derivs_dim) {
cudaF_compute_xvector_objf(Gr, Bl, scores, scores_dim, objf_terms, objf_dim,
objf_derivs, derivs_dim);
}


// double versions
@@ -467,6 +474,14 @@ inline void cuda_equal_element_mask(dim3 Gr, dim3 Bl, const double *mat1, const
cudaD_equal_element_mask(Gr, Bl, mat1, mat2, mask, mat1_dim, mat2_stride, mask_stride);
}

inline void cuda_compute_xvector_objf(dim3 Gr, dim3 Bl, const double *scores,
MatrixDim scores_dim, double *objf_terms,
MatrixDim objf_dim, double *objf_derivs,
MatrixDim derivs_dim) {
cudaD_compute_xvector_objf(Gr, Bl, scores, scores_dim, objf_terms, objf_dim,
objf_derivs, derivs_dim);
}

// Also include some template-friendly wrappers of cublas functions:
inline cublasStatus_t cuda_axpy(cublasHandle_t handle, int n, float alpha, const float *x, int incx, float *y, int incy) {
return cublasSaxpy_v2(handle, n, &alpha, x, incx, y, incy);
73 changes: 54 additions & 19 deletions src/cudamatrix/cu-math.cc
@@ -2,6 +2,7 @@

// Copyright 2009-2012 Karel Vesely
// Johns Hopkins University (author: Daniel Povey)
// 2016 David Snyder

// See ../../COPYING for clarification regarding multiple authors
//
@@ -29,15 +30,15 @@ namespace kaldi {
namespace cu {

/*
* templated functions wrapping the ANSI-C CUDA kernel functions
*/


template<typename Real>
void RegularizeL1(CuMatrixBase<Real> *weight, CuMatrixBase<Real> *grad, Real l1, Real lr) {
KALDI_ASSERT(SameDim(*weight, *grad));
#if HAVE_CUDA == 1
if (CuDevice::Instantiate().Enabled()) {
Timer tim;

dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
@@ -46,7 +47,7 @@ void RegularizeL1(CuMatrixBase<Real> *weight, CuMatrixBase<Real> *grad, Real l1,
cuda_regularize_l1(dimGrid, dimBlock, weight->Data(), grad->Data(), l1, lr,
weight->Dim(), grad->Stride());
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
@@ -55,11 +56,11 @@ void RegularizeL1(CuMatrixBase<Real> *weight, CuMatrixBase<Real> *grad, Real l1,
MatrixBase<Real> &grad2 = grad->Mat();
for(MatrixIndexT r=0; r<weight2.NumRows(); r++) {
for(MatrixIndexT c=0; c<weight2.NumCols(); c++) {

if(weight2(r,c)==0.0) continue; // skip L1 if zero weight!

Real l1_signed = l1;
if (weight2(r, c) < 0.0)
l1_signed = -l1;

Real before = weight2(r, c);
@@ -88,16 +89,16 @@ void Randomize(const CuMatrixBase<Real> &src,
#if HAVE_CUDA == 1
if (CuDevice::Instantiate().Enabled()) {
Timer tim;

/*
Note: default 16x16 block-size limits the --cachesize to matrix size 16*65535 x 16*65535
dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
dim3 dimGrid(n_blocks(tgt->NumCols(), CU2DBLOCK), n_blocks(copy_from_idx.Dim(), CU2DBLOCK));
*/

/*
* Let's use blocksize 4 x 128 (512 threads/block)
* and extend the randomizable matrices to: col 4*65535, row 128*65535
* (ie. max-cols:262140 (dim), max-rows:8388480 (datapoints))
*/
dim3 dimBlock(4, 128);
@@ -111,7 +112,7 @@ void Randomize(const CuMatrixBase<Real> &src,
cuda_randomize(dimGrid, dimBlock, tgt->Data(), src.Data(),
copy_from_idx.Data(), dimtgt, dimsrc);
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
@@ -124,28 +125,28 @@ void Randomize(const CuMatrixBase<Real> &src,
tgtmat.Row(i).CopyFromVec(srcmat.Row(copy_from_idxvec[i]));
}
}
}



template<typename Real>
void Splice(const CuMatrixBase<Real> &src, const CuArray<int32> &frame_offsets,
CuMatrixBase<Real> *tgt) {

KALDI_ASSERT(src.NumCols()*frame_offsets.Dim() == tgt->NumCols());
KALDI_ASSERT(src.NumRows() == tgt->NumRows());

#if HAVE_CUDA == 1
if (CuDevice::Instantiate().Enabled()) {
Timer tim;

dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
dim3 dimGrid(n_blocks(tgt->NumCols(), CU2DBLOCK), n_blocks(tgt->NumRows(), CU2DBLOCK));

cuda_splice(dimGrid, dimBlock, tgt->Data(), src.Data(),
frame_offsets.Data(), tgt->Dim(), src.Dim());
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
@@ -171,22 +172,22 @@ void Splice(const CuMatrixBase<Real> &src, const CuArray<int32> &frame_offsets,

template<typename Real>
void Copy(const CuMatrixBase<Real> &src, const CuArray<int32> &copy_from_indices,
CuMatrixBase<Real> *tgt) {

KALDI_ASSERT(copy_from_indices.Dim() == tgt->NumCols());
KALDI_ASSERT(src.NumRows() == tgt->NumRows());

#if HAVE_CUDA == 1
if (CuDevice::Instantiate().Enabled()) {
Timer tim;

dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
dim3 dimGrid(n_blocks(tgt->NumCols(), CU2DBLOCK), n_blocks(tgt->NumRows(), CU2DBLOCK));

cuda_copy(dimGrid, dimBlock, tgt->Data(), src.Data(),
copy_from_indices.Data(), tgt->Dim(), src.Dim());
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
@@ -205,6 +206,31 @@ void Copy(const CuMatrixBase<Real> &src, const CuArray<int32> &copy_from_indices
}
}

template<typename Real>
void ComputeXvectorObjfFromScores(const CuMatrixBase<Real> &scores,
CuMatrixBase<Real> *objf_terms,
CuMatrixBase<Real> *objf_derivs) {
#if HAVE_CUDA == 1
Owner:
check the dimensions at the beginning of this function-- KALDI_ASSERT(SameDim(scores, *objf_terms) && ..)

if (CuDevice::Instantiate().Enabled()) {
Timer tim;
dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
dim3 dimGrid(n_blocks(scores.NumCols(), CU2DBLOCK),
n_blocks(scores.NumRows(), CU2DBLOCK));

cuda_compute_xvector_objf(dimGrid, dimBlock, scores.Data(), scores.Dim(),
objf_terms->Data(), objf_terms->Dim(), objf_derivs->Data(),
objf_derivs->Dim());
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
{
// TODO: Add the CPU version.
KALDI_LOG << "NOT USING CUDA";
}
}

// instantiate the templates.
template
void RegularizeL1(CuMatrixBase<float> *weight, CuMatrixBase<float> *grad, float l1, float lr);
@@ -233,6 +259,15 @@ void Randomize(const CuMatrixBase<double> &src,
const CuArray<int32> &copy_from_idx,
CuMatrixBase<double> *tgt);

template
Owner:
Since this is kind of a special purpose function, you don't have to instantiate for both float and double-- you can just hard-code it to BaseFloat. [i.e. not a template]

void ComputeXvectorObjfFromScores(const CuMatrixBase<float> &scores,
CuMatrixBase<float> *objf_terms,
CuMatrixBase<float> *objf_derivs);
template
void ComputeXvectorObjfFromScores(const CuMatrixBase<double> &scores,
CuMatrixBase<double> *objf_terms,
CuMatrixBase<double> *objf_derivs);



} //namespace cu
Expand Down
11 changes: 9 additions & 2 deletions src/cudamatrix/cu-math.h
@@ -1,7 +1,8 @@
// cudamatrix/cu-math.h

// Copyright 2009-2012 Karel Vesely
// 2013 Johns Hopkins University (Author: David Snyder)
// 2013 Johns Hopkins University (Author: Daniel Povey)
// 2016 David Snyder

// See ../../COPYING for clarification regarding multiple authors
//
@@ -78,7 +79,13 @@ void Group2norm(const CuMatrixBase<Real> &src,
CuMatrixBase<Real> *dest,
int32 group_stride);


/*
Computes, for a matrix of pairwise scores between xvectors, the matrix of
objective-function terms and the matrix of derivatives of the objective
w.r.t. the scores (see _compute_xvector_objf in cu-kernels.cu). The output
matrices should have the same dimensions as the scores.
*/
template <typename BaseFloat>
void ComputeXvectorObjfFromScores(const CuMatrixBase<BaseFloat> &scores,
CuMatrixBase<BaseFloat> *objf_terms,
CuMatrixBase<BaseFloat> *objf_derivs);


} // namespace cu
Expand Down
12 changes: 8 additions & 4 deletions src/ivector/Makefile
@@ -5,14 +5,18 @@ OPENFST_CXXFLAGS =
OPENFST_LDLIBS =
include ../kaldi.mk

TESTFILES = ivector-extractor-test plda-test logistic-regression-test
LDFLAGS += $(CUDA_LDFLAGS)
LDLIBS += $(CUDA_LDLIBS)

OBJFILES = ivector-extractor.o voice-activity-detection.o plda.o logistic-regression.o
TESTFILES = ivector-extractor-test plda-test logistic-regression-test xvector-test

OBJFILES = ivector-extractor.o voice-activity-detection.o plda.o logistic-regression.o xvector.o

LIBNAME = kaldi-ivector

ADDLIBS = ../gmm/kaldi-gmm.a ../tree/kaldi-tree.a ../transform/kaldi-transform.a \
../thread/kaldi-thread.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a \
../util/kaldi-util.a
../thread/kaldi-thread.a ../cudamatrix/kaldi-cudamatrix.a \
../matrix/kaldi-matrix.a ../base/kaldi-base.a \
../util/kaldi-util.a

include ../makefiles/default_rules.mk