diff --git a/src/cudafeat/Makefile b/src/cudafeat/Makefile
index 913c1ea9dbb..dff0dd63174 100644
--- a/src/cudafeat/Makefile
+++ b/src/cudafeat/Makefile
@@ -8,14 +8,16 @@ ifeq ($(CUDA), true)
 TESTFILES =
 
 ifeq ($(CUDA), true)
-  OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o feature-online-cmvn-cuda.o
+  OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o feature-online-cmvn-cuda.o \
+              online-ivector-feature-cuda-kernels.o online-ivector-feature-cuda.o \
+              online-cuda-feature-pipeline.o
 endif
 
 LIBNAME = kaldi-cudafeat
 
 ADDLIBS = ../feat/kaldi-feat.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
           ../base/kaldi-base.a ../cudamatrix/kaldi-cudamatrix.a \
-          ../gmm/kaldi-gmm.a ../online2/kaldi-online2.a
+          ../gmm/kaldi-gmm.a ../ivector/kaldi-ivector.a ../online2/kaldi-online2.a
 
 LDFLAGS += $(CUDA_LDFLAGS)
 LDLIBS += $(CUDA_LDLIBS)
diff --git a/src/cudafeat/feature-mfcc-cuda.cu b/src/cudafeat/feature-mfcc-cuda.cu
index 730e7bd47e7..cd7347601af 100644
--- a/src/cudafeat/feature-mfcc-cuda.cu
+++ b/src/cudafeat/feature-mfcc-cuda.cu
@@ -14,8 +14,11 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
+#if HAVE_CUDA == 1
 #include <nvToolsExt.h>
 #include <cub/cub.cuh>
+#endif
 
 #include "cudafeat/feature-mfcc-cuda.h"
 #include "cudamatrix/cu-rand.h"
diff --git a/src/cudafeat/feature-window-cuda.cu b/src/cudafeat/feature-window-cuda.cu
index 0c98bee30ba..7ce7d798ca2 100644
--- a/src/cudafeat/feature-window-cuda.cu
+++ b/src/cudafeat/feature-window-cuda.cu
@@ -15,7 +15,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#if HAVE_CUDA == 1
 #include <nvToolsExt.h>
+#endif
 
 #include "cudafeat/feature-window-cuda.h"
 #include "matrix/matrix-functions.h"
diff --git a/src/cudafeat/online-cuda-feature-pipeline.cc b/src/cudafeat/online-cuda-feature-pipeline.cc
new file mode 100644
index 00000000000..4fd092b4f05
--- /dev/null
+++ b/src/cudafeat/online-cuda-feature-pipeline.cc
@@ -0,0 +1,70 @@
+// cudafeat/online-cuda-feature-pipeline.cc
+
+// Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABILITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
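+
+// Example call sequence (a sketch; see compute-online-feats-cuda.cc for the
+// real driver; names like `feature_opts` are illustrative):
+//   OnlineCudaFeaturePipeline pipeline(feature_opts);
+//   CuVector<BaseFloat> cu_wave(waveform);  // whole utterance, on the device
+//   CuMatrix<BaseFloat> feats;
+//   CuVector<BaseFloat> ivector;
+//   pipeline.ComputeFeatures(cu_wave, sample_freq, &feats, &ivector);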
+ +#include "cudafeat/online-cuda-feature-pipeline.h" + +namespace kaldi { + +OnlineCudaFeaturePipeline::OnlineCudaFeaturePipeline( + const OnlineNnet2FeaturePipelineConfig &config) + : info_(config), mfcc(NULL), ivector(NULL) { + if (info_.feature_type == "mfcc") { + mfcc = new CudaMfcc(info_.mfcc_opts); + } + + if (info_.use_ivectors) { + OnlineIvectorExtractionConfig ivector_extraction_opts; + ReadConfigFromFile(config.ivector_extraction_config, + &ivector_extraction_opts); + info_.ivector_extractor_info.Init(ivector_extraction_opts); + + // Only these ivector options are currently supported + ivector_extraction_opts.use_most_recent_ivector = true; + ivector_extraction_opts.greedy_ivector_extractor = true; + + ivector = new IvectorExtractorFastCuda(ivector_extraction_opts); + } +} + +OnlineCudaFeaturePipeline::~OnlineCudaFeaturePipeline() { + if (mfcc != NULL) delete mfcc; + if (ivector != NULL) delete ivector; +} + +void OnlineCudaFeaturePipeline::ComputeFeatures( + const CuVectorBase &cu_wave, BaseFloat sample_freq, + CuMatrix *input_features, + CuVector *ivector_features) { + if (info_.feature_type == "mfcc") { + // MFCC + float vtln_warp = 1.0; + mfcc->ComputeFeatures(cu_wave, sample_freq, vtln_warp, input_features); + } else { + KALDI_ASSERT(false); + } + + // Ivector + if (info_.use_ivectors && ivector_features != NULL) { + ivector->GetIvector(*input_features, ivector_features); + } else { + KALDI_ASSERT(false); + } +} + +} // namespace kaldi diff --git a/src/cudafeat/online-cuda-feature-pipeline.h b/src/cudafeat/online-cuda-feature-pipeline.h new file mode 100644 index 00000000000..5c71d37b395 --- /dev/null +++ b/src/cudafeat/online-cuda-feature-pipeline.h @@ -0,0 +1,55 @@ +// cudafeat/online-cuda-feature-pipeline.h + +// Copyright 2013-2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef KALDI_CUDAFEAT_ONLINE_CUDA_FEATURE_PIPELINE_H_
+#define KALDI_CUDAFEAT_ONLINE_CUDA_FEATURE_PIPELINE_H_
+
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "base/kaldi-error.h"
+#include "cudafeat/feature-mfcc-cuda.h"
+#include "cudafeat/online-ivector-feature-cuda.h"
+#include "matrix/matrix-lib.h"
+#include "online2/online-nnet2-feature-pipeline.h"
+#include "util/common-utils.h"
+
+namespace kaldi {
+
+class OnlineCudaFeaturePipeline {
+ public:
+  explicit OnlineCudaFeaturePipeline(
+      const OnlineNnet2FeaturePipelineConfig &config);
+
+  void ComputeFeatures(const CuVectorBase<BaseFloat> &cu_wave,
+                       BaseFloat sample_freq,
+                       CuMatrix<BaseFloat> *input_features,
+                       CuVector<BaseFloat> *ivector_features);
+
+  ~OnlineCudaFeaturePipeline();
+
+ private:
+  OnlineNnet2FeaturePipelineInfo info_;
+  CudaMfcc *mfcc;
+  IvectorExtractorFastCuda *ivector;
+};
+}  // namespace kaldi
+
+#endif  // KALDI_CUDAFEAT_ONLINE_CUDA_FEATURE_PIPELINE_H_
diff --git a/src/cudafeat/online-ivector-feature-cuda-kernels.cu b/src/cudafeat/online-ivector-feature-cuda-kernels.cu
new file mode 100644
index 00000000000..227f49deb63
--- /dev/null
+++ b/src/cudafeat/online-ivector-feature-cuda-kernels.cu
@@ -0,0 +1,239 @@
+// cudafeat/online-ivector-feature-cuda-kernels.cu
+//
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
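+
+// For reference: batched_gemv_reduce() below computes, over batches b,
+//   C += sum_b A_b^T * x_b
+// where each 32x32 thread block handles one batch's matrix-vector product
+// and atomically accumulates its partial result into the shared output C.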
+
+#include <cub/cub.cuh>
+
+#include "cudafeat/online-ivector-feature-cuda-kernels.h"
+#include "cudamatrix/cu-common.h"
+
+namespace kaldi {
+
+// Meant to be called with blockDim = 32x32
+__global__ void batched_gemv_reduce_kernel(int rows, int cols,
+                                           const float* __restrict__ A,
+                                           int lda,
+                                           const float* __restrict__ X,
+                                           int ldx, float* C) {
+  // Specialize WarpReduce for type float
+  typedef cub::WarpReduce<float> WarpReduce;
+  // Allocate WarpReduce shared memory for 32 warps
+  __shared__ typename WarpReduce::TempStorage temp_storage[32];
+
+  __shared__ float s_A[32][32 + 1];  // +1 to avoid bank conflicts on transpose
+
+  int bid = blockIdx.x;   // batch id
+  int tid = threadIdx.x;  // thread id
+  int wid = threadIdx.y;  // warp id
+
+  // Offset the input matrix to the starting row for this batch
+  const float* __restrict__ A_in = A + bid * rows * lda;
+  // Offset the input vector to the starting column for this batch
+  const float* __restrict__ X_in = X + bid * ldx;
+
+  for (int i = 0; i < cols; i += 32) {  // threadIdx.x, keep all threads present
+    int c = i + tid;
+
+    float sum = 0.0f;
+    // Perform dot product
+    for (int j = 0; j < rows;
+         j += 32) {  // threadIdx.y, keep all threads present
+      int r = j + wid;
+
+      float val = 0.0f;
+      if (c < cols && r < rows) {
+        // coalesced reads
+        val = A_in[r * lda + c] * X_in[r];
+      }
+
+      // write to shared memory
+      __syncthreads();  // wait for the previous iteration's reads to finish
+      s_A[wid][tid] = val;
+      __syncthreads();  // wait for all writes to shared memory to finish
+
+      // transposed read from shared memory, collecting the sum
+      sum += s_A[tid][wid];
+    }
+    // reduce sum with cub
+    sum = WarpReduce(temp_storage[wid]).Sum(sum);
+
+    // update c now that we are transposed
+    c = i + wid;
+    if (tid == 0 && c < cols) {
+      // Add this batch's contribution to the final sum.
+      // The atomic is necessary because different batches update the same
+      // output elements.
+      atomicAdd(&C[c], sum);
+    }
+  }
+}
+
+// Computes feats^2 elementwise. This works both in place and out of place.
+__global__ void square_matrix_kernel(int32_t num_rows, int32_t num_cols,
+                                     const float* feats, int32_t ldf,
+                                     float* feats_sq, int32_t lds) {
+  for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_rows;
+       i += blockDim.y * gridDim.y) {
+    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_cols;
+         j += blockDim.x * gridDim.x) {
+      float f = feats[i * ldf + j];
+      feats_sq[i * lds + j] = f * f;
+    }
+  }
+}
+
+// Takes features in feats and writes them into sfeats while applying
+// the splicing algorithm for the left and right context.
+// Input frames that are out of range are clamped.
+__global__ void splice_features_kernel(int32_t num_frames, int32_t feat_dim,
+                                       int32_t left, int32_t size,
+                                       const float* __restrict__ feats,
+                                       int32_t ldf,
+                                       float* __restrict__ sfeats,
+                                       int32_t lds) {
+  int32_t frame = blockIdx.x;
+  int32_t tid = threadIdx.x;
+
+  // offset the feature output to the frame being processed
+  float* feat_out = sfeats + lds * frame;
+
+  // for each splice of input
+  for (int i = 0; i < size; i++) {
+    int r = frame + i + left;
+    // clamp the input row
+    if (r < 0) r = 0;
+    if (r > num_frames - 1) r = num_frames - 1;
+
+    // for each column of input in parallel
+    for (int c = tid; c < feat_dim; c += blockDim.x) {
+      // read the feature from the input row, offset by column
+      float val = feats[r * ldf + c];
+
+      // write the feature to the output, offset by splice index and column
+      feat_out[i * feat_dim + c] = val;
+    }
+  }
+}
+
+// Computes the sum of all terms in a matrix.
+// The kernel double buffers the output such that the
+// output is written to retval[b] where b is 0 or 1,
+// while the other element of retval is reset to zero.
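+// That is, retval[b] accumulates this call's sum while retval[1 - b] is
+// cleared, ready for the next call.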
+// Double buffering eliminates a call to cudaMemset.
+__global__ void get_matrix_sum_double_buffer_kernel(int32_t b,
+                                                    int32_t num_rows,
+                                                    int32_t num_cols, float* A,
+                                                    int32_t lda, float scale,
+                                                    float* retval) {
+  // Specialize BlockReduce for type float and a 32x32 thread block
+  typedef cub::BlockReduce<float, 32, cub::BLOCK_REDUCE_WARP_REDUCTIONS, 32>
+      BlockReduce;
+  // Allocate BlockReduce shared memory
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  float sum = 0.0f;
+
+  // compute local sums in parallel
+  for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_rows;
+       i += blockDim.y * gridDim.y) {
+    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_cols;
+         j += blockDim.x * gridDim.x) {
+      sum += A[i * lda + j];
+    }
+  }
+
+  sum = BlockReduce(temp_storage).Sum(sum);
+
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    atomicAdd(&retval[b], sum * scale);
+    int next_b = (b + 1) % 2;
+    retval[next_b] = 0.0f;
+  }
+}
+
+// This kernel updates the linear and quadratic terms.
+// It does not support a previous weight coming in and would need to be
+// updated for online decoding.
+__global__ void update_linear_and_quadratic_terms_kernel(
+    int32_t n, float prior_offset, float* cur_tot_weight, int32_t max_count,
+    float* quadratic, float* linear) {
+  float val = 1.0f;
+  float cur_weight = *cur_tot_weight;
+
+  if (max_count > 0) {
+    float new_scale = max(cur_weight, (float)max_count) / max_count;
+    float prior_scale_change = new_scale - 1.0f;
+    val += prior_scale_change;
+  }
+
+  // add val to the diagonal of the packed quadratic term
+  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    int32_t diag_idx = ((i + 1) * (i + 2) / 2) - 1;
+    quadratic[diag_idx] += val;
+  }
+
+  if (threadIdx.x == 0) {
+    linear[0] += val * prior_offset;
+  }
+}
+
+void batched_gemv_reduce(int batch_size, int rows, int cols, int A_stride,
+                         const float* AT, int B_stride, const float* B,
+                         const float* y, float* C) {
+  dim3 threads(32, 32);
+  batched_gemv_reduce_kernel<<<batch_size, threads>>>(rows, cols, AT, A_stride,
+                                                      B, B_stride, C);
+  CU_SAFE_CALL(cudaGetLastError());
+}
+
+void splice_features(int32_t num_frames, int32_t feat_dim, int32_t left,
+                     int32_t size, const float* feats, int32_t ldf,
+                     float* sfeats, int32_t lds) {
+  int threads = (feat_dim + 31) / 32 * 32;  // round up to the nearest warp size
+  if (threads > 1024) threads = 1024;  // the max block size is 1024 threads
+
+  splice_features_kernel<<<num_frames, threads>>>(
+      num_frames, feat_dim, left, size, feats, ldf, sfeats, lds);
+  CU_SAFE_CALL(cudaGetLastError());
+}
+
+void update_linear_and_quadratic_terms(int32_t n, float prior_offset,
+                                       float* cur_tot_weight,
+                                       int32_t max_count, float* quadratic,
+                                       float* linear) {
+  // Only using 1 CTA here for now as the updates are tiny and this lets us
+  // use syncthreads as a global barrier.
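+  // For reference, the scalar math the kernel applies is roughly this
+  // (cf. the CPU path in ivector/ivector-extractor.cc):
+  //   val = 1; if (max_count > 0) val += max(tot_weight, max_count) / max_count - 1;
+  //   quadratic(i, i) += val for all i;  linear(0) += val * prior_offset;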
+  update_linear_and_quadratic_terms_kernel<<<1, 1024>>>(
+      n, prior_offset, cur_tot_weight, max_count, quadratic, linear);
+  CU_SAFE_CALL(cudaGetLastError());
+}
+
+void get_matrix_sum_double_buffer(int32_t b, int32_t num_rows,
+                                  int32_t num_cols, float* A, int32_t lda,
+                                  float scale, float* sum) {
+  dim3 threads(32, 32);
+  dim3 blocks((num_cols + threads.x - 1) / threads.x,
+              (num_rows + threads.y - 1) / threads.y);
+
+  get_matrix_sum_double_buffer_kernel<<<blocks, threads>>>(
+      b, num_rows, num_cols, A, lda, scale, sum);
+  CU_SAFE_CALL(cudaGetLastError());
+}
+
+void square_matrix(int32_t num_rows, int32_t num_cols, const float* feats,
+                   int32_t ldf, float* feats_sq, int32_t lds) {
+  dim3 threads(32, 32);
+  dim3 blocks((num_cols + threads.x - 1) / threads.x,
+              (num_rows + threads.y - 1) / threads.y);
+
+  square_matrix_kernel<<<blocks, threads>>>(num_rows, num_cols, feats, ldf,
+                                            feats_sq, lds);
+  CU_SAFE_CALL(cudaGetLastError());
+}
+}  // namespace kaldi
diff --git a/src/cudafeat/online-ivector-feature-cuda-kernels.h b/src/cudafeat/online-ivector-feature-cuda-kernels.h
new file mode 100644
index 00000000000..62407b77b2b
--- /dev/null
+++ b/src/cudafeat/online-ivector-feature-cuda-kernels.h
@@ -0,0 +1,40 @@
+// cudafeat/online-ivector-feature-cuda-kernels.h
+//
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef CUDAFEAT_ONLINE_IVECTOR_FEATURE_CUDA_KERNELS
+#define CUDAFEAT_ONLINE_IVECTOR_FEATURE_CUDA_KERNELS
+
+#include <stdint.h>
+
+namespace kaldi {
+void batched_gemv_reduce(int batch_size, int rows, int cols, int A_stride,
+                         const float *AT, int B_stride, const float *B,
+                         const float *y, float *C);
+
+void splice_features(int32_t num_frames, int32_t feat_dim, int32_t left,
+                     int32_t size, const float *feats, int32_t ldf,
+                     float *sfeats, int32_t lds);
+
+void update_linear_and_quadratic_terms(int32_t n, float prior_offset,
+                                       float *cur_tot_weight,
+                                       int32_t max_count, float *quadratic,
+                                       float *linear);
+
+void get_matrix_sum_double_buffer(int32_t b, int32_t num_rows,
+                                  int32_t num_cols, float *A, int32_t lda,
+                                  float scale, float *sum);
+
+void square_matrix(int32_t num_rows, int32_t num_cols, const float *feats,
+                   int32_t ldf, float *feats_sq, int32_t lds);
+}  // namespace kaldi
+#endif  // CUDAFEAT_ONLINE_IVECTOR_FEATURE_CUDA_KERNELS
diff --git a/src/cudafeat/online-ivector-feature-cuda.cc b/src/cudafeat/online-ivector-feature-cuda.cc
new file mode 100644
index 00000000000..192e8d25686
--- /dev/null
+++ b/src/cudafeat/online-ivector-feature-cuda.cc
@@ -0,0 +1,278 @@
+// cudafeat/online-ivector-feature-cuda.cc
+//
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if HAVE_CUDA == 1
+#include <nvToolsExt.h>
+#endif
+#include <vector>
+
+#include "base/io-funcs.h"
+#include "base/kaldi-common.h"
+#include "base/timer.h"
+#include "cudafeat/feature-online-cmvn-cuda.h"
+#include "cudafeat/online-ivector-feature-cuda-kernels.h"
+#include "cudafeat/online-ivector-feature-cuda.h"
+#include "cudamatrix/cu-device.h"
+#include "cudamatrix/cu-sp-matrix.h"
+#include "gmm/diag-gmm.h"
+#include "util/kaldi-io.h"
+#include "util/table-types.h"
+
+namespace kaldi {
+
+void IvectorExtractorFastCuda::GetIvector(const CuMatrixBase<float> &feats,
+                                          CuVector<float> *ivector) {
+  nvtxRangePushA("GetIvector");
+  CuMatrix<float> posteriors, X;
+  CuVector<float> gamma;
+
+  // normalized pipeline
+  CuMatrix<float> lda_feats_normalized(feats.NumRows(), feats.NumCols(),
+                                       kUndefined);
+  {
+    CudaOnlineCmvn cmvn(info_.cmvn_opts, naive_cmvn_state_);
+    CuMatrix<float> cmvn_feats(feats.NumRows(), feats.NumCols(), kUndefined);
+    CuMatrix<float> spliced_feats_normalized;
+
+    // Normalize
+    cmvn.ComputeFeatures(feats, &cmvn_feats);
+
+    // Splice
+    SpliceFeats(cmvn_feats, &spliced_feats_normalized);
+
+    // Transform by the LDA matrix
+    lda_feats_normalized.AddMatMat(1.0, spliced_feats_normalized, kNoTrans,
+                                   cu_lda_, kTrans, 0.0);
+  }
+
+  // non-normalized pipeline
+  CuMatrix<float> lda_feats(feats.NumRows(), feats.NumCols(), kUndefined);
+  {
+    CuMatrix<float> spliced_feats;
+
+    // Splice feats
+    SpliceFeats(feats, &spliced_feats);
+
+    // Transform by the LDA matrix
+    lda_feats.AddMatMat(1.0, spliced_feats, kNoTrans, cu_lda_, kTrans, 0.0);
+  }
+
+  // posteriors are based on the normalized feats
+  ComputePosteriors(lda_feats_normalized, &posteriors);
+
+  // stats are based on the non-normalized feats
+  ComputeIvectorStats(lda_feats, posteriors, &gamma, &X);
+
+  ComputeIvectorFromStats(gamma, X, ivector);
+
+  nvtxRangePop();
+}
+
+void IvectorExtractorFastCuda::Read(
+    const kaldi::OnlineIvectorExtractionConfig &config) {
+  // read the ubm
+  DiagGmm gmm;
+  ReadKaldiObject(config.diag_ubm_rxfilename, &gmm);
+  ubm_gconsts_.Resize(gmm.NumGauss());
+  ubm_gconsts_.CopyFromVec(gmm.gconsts());
+  ubm_means_inv_vars_.Resize(gmm.NumGauss(), gmm.Dim());
+  ubm_means_inv_vars_.CopyFromMat(gmm.means_invvars());
+  ubm_inv_vars_.Resize(gmm.NumGauss(), gmm.Dim());
+  ubm_inv_vars_.CopyFromMat(gmm.inv_vars());
+  num_gauss_ = gmm.NumGauss();
+
+  // read the extractor (copied from ivector/ivector-extractor.cc)
+  bool binary;
+  Input input(config.ivector_extractor_rxfilename, &binary);
+  Matrix<float> w;
+  Vector<float> w_vec;
+  std::vector<Matrix<float> > ie_M;
+  std::vector<SpMatrix<float> > ie_Sigma_inv;
+
+  ExpectToken(input.Stream(), binary, "<IvectorExtractor>");
+  ExpectToken(input.Stream(), binary, "<w>");
+  w.Read(input.Stream(), binary);
+  ExpectToken(input.Stream(), binary, "<w_vec>");
+  w_vec.Read(input.Stream(), binary);
+  ExpectToken(input.Stream(), binary, "<M>");
+  int32 size;
+  ReadBasicType(input.Stream(), binary, &size);
+  KALDI_ASSERT(size > 0);
+  ie_M.resize(size);
+  for (int32 i = 0; i < size; i++) {
+    ie_M[i].Read(input.Stream(), binary);
+  }
+  ExpectToken(input.Stream(), binary, "<SigmaInv>");
+  ie_Sigma_inv.resize(size);
+  for (int32 i = 0; i < size; i++) {
+    ie_Sigma_inv[i].Read(input.Stream(), binary);
+  }
+  ExpectToken(input.Stream(), binary, "<IvectorOffset>");
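+  // prior_offset_ is the offset of the first i-vector dimension under the
+  // prior distribution; it is subtracted again at the end of
+  // ComputeIvectorFromStats().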
+  ReadBasicType(input.Stream(), binary, &prior_offset_);
+  ExpectToken(input.Stream(), binary, "</IvectorExtractor>");
+
+  // compute derived variables
+  ivector_dim_ = ie_M[0].NumCols();
+  feat_dim_ = ie_M[0].NumRows();
+
+  ie_Sigma_inv_M_f_.Resize(num_gauss_ * feat_dim_, ivector_dim_);
+
+  ie_U_.Resize(num_gauss_, ivector_dim_ * (ivector_dim_ + 1) / 2);
+
+  SpMatrix<float> tmp_sub_U(ivector_dim_);
+  Matrix<float> tmp_Sigma_inv_M(feat_dim_, ivector_dim_);
+  for (int32 i = 0; i < num_gauss_; i++) {
+    // compute the packed matrix ie_U[i] = M[i]^T Sigma_inv[i] M[i]
+    tmp_sub_U.AddMat2Sp(1, ie_M[i], kTrans, ie_Sigma_inv[i], 0);
+    SubVector<float> tmp_U_vec(tmp_sub_U.Data(),
+                               ivector_dim_ * (ivector_dim_ + 1) / 2);
+    ie_U_.Row(i).CopyFromVec(tmp_U_vec);
+
+    // compute the matrix ie_Sigma_inv_M[i] = Sigma_inv[i] * M[i]
+    tmp_Sigma_inv_M.AddSpMat(1, ie_Sigma_inv[i], ie_M[i], kNoTrans, 0);
+
+    // copy it into the global batched matrix
+    CuSubMatrix<float> window(ie_Sigma_inv_M_f_, i * feat_dim_, feat_dim_, 0,
+                              ivector_dim_);
+    window.CopyFromMat(tmp_Sigma_inv_M);
+  }
+}
+
+void IvectorExtractorFastCuda::SpliceFeats(const CuMatrixBase<float> &feats,
+                                           CuMatrix<float> *spliced_feats) {
+  int left = -info_.splice_opts.left_context;
+  int right = info_.splice_opts.right_context;
+  int size = right - left + 1;
+  spliced_feats->Resize(feats.NumRows(), feats.NumCols() * size, kUndefined);
+
+  splice_features(feats.NumRows(), feats.NumCols(), left, size, feats.Data(),
+                  feats.Stride(), spliced_feats->Data(),
+                  spliced_feats->Stride());
+}
+
+void IvectorExtractorFastCuda::ComputePosteriors(
+    const CuMatrixBase<float> &feats, CuMatrix<float> *posteriors) {
+  int num_frames = feats.NumRows();
+
+  posteriors->Resize(num_frames, num_gauss_, kUndefined);
+
+  // start from the per-Gaussian constant term of the log-likelihood
+  posteriors->CopyRowsFromVec(ubm_gconsts_);
+
+  CuMatrix<float> feats_sq(feats.NumRows(), feats.NumCols(), kUndefined);
+
+  // using our own kernel here to avoid an extra memcpy;
+  // ApplyPow unfortunately only works in place.
+  square_matrix(feats.NumRows(), feats.NumCols(), feats.Data(), feats.Stride(),
+                feats_sq.Data(), feats_sq.Stride());
+
+  // complete the diagonal-GMM log-likelihoods:
+  //   loglike(t, g) = gconst(g) + mu_g' Sigma_g^-1 x_t - 0.5 x_t' Sigma_g^-1 x_t
+  posteriors->AddMatMat(1.0, feats, kNoTrans, ubm_means_inv_vars_, kTrans, 1.0);
+  posteriors->AddMatMat(-0.5, feats_sq, kNoTrans, ubm_inv_vars_, kTrans, 1.0);
+
+  // normalize each row with a softmax, turning log-likelihoods into posteriors
+  posteriors->ApplySoftMaxPerRow(*posteriors);
+
+  if (info_.max_count > 0) {
+    // when max_count > 0 we need to know the total posterior sum to adjust
+    // the prior offset, so calculate that here.
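+    // After this call, tot_post_[b_] holds
+    //   posterior_scale * sum_{t,g} posteriors(t, g),
+    // i.e. the total (scaled) posterior count for the utterance.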
+    get_matrix_sum_double_buffer(
+        b_, posteriors->NumRows(), posteriors->NumCols(), posteriors->Data(),
+        posteriors->Stride(), info_.posterior_scale, tot_post_.Data());
+  }
+}
+
+void IvectorExtractorFastCuda::ComputeIvectorStats(
+    const CuMatrixBase<float> &feats, const CuMatrixBase<float> &posteriors,
+    CuVector<float> *gamma, CuMatrix<float> *X) {
+  gamma->Resize(num_gauss_, kUndefined);
+  X->Resize(num_gauss_, feat_dim_, kUndefined);
+
+  gamma->AddRowSumMat(info_.posterior_scale, posteriors, 0.0f);
+  X->AddMatMat(info_.posterior_scale, posteriors, kTrans, feats, kNoTrans,
+               0.0f);
+}
+
+void IvectorExtractorFastCuda::ComputeIvectorFromStats(
+    const CuVector<float> &gamma, const CuMatrix<float> &X,
+    CuVector<float> *ivector) {
+  CuVector<float> &linear = *ivector;
+  linear.Resize(ivector_dim_, kUndefined);
+  // Initialize to zero as the batched kernel accumulates with +=
+  linear.SetZero();
+
+  CuSpMatrix<float> quadratic(ivector_dim_, kUndefined);
+
+  batched_gemv_reduce(num_gauss_, feat_dim_, ivector_dim_,
+                      ie_Sigma_inv_M_f_.Stride(), ie_Sigma_inv_M_f_.Data(),
+                      X.Stride(), X.Data(), gamma.Data(), linear.Data());
+
+  CuSubVector<float> q_vec(quadratic.Data(),
+                           ivector_dim_ * (ivector_dim_ + 1) / 2);
+  q_vec.AddMatVec(1.0f, ie_U_, kTrans, gamma, 0.0f);
+
+  // compute and apply the prior offset to the linear and quadratic terms;
+  // tot_post_ is offset by the current double-buffer index
+  update_linear_and_quadratic_terms(quadratic.NumRows(), prior_offset_,
+                                    tot_post_.Data() + b_, info_.max_count,
+                                    quadratic.Data(), linear.Data());
+  // advance the double buffer
+  b_ = (b_ + 1) % 2;
+
+  // We are solving the linear system
+  //   quadratic * ivector = linear
+  // Inverting the matrix is unnecessary: we only need the solution for a
+  // single right-hand side, so a Cholesky factorization followed by a
+  // triangular solve is enough.
+
+  int nrhs = 1;
+
+  // Form a new non-packed matrix for cusolver.
+  CuMatrix<float> A(quadratic);
+
+  // This is the cusolver return code. Checking it would require
+  // synchronization, so we do not check it.
+  int *d_info = NULL;
+
+  // query the temp buffer size
+  int L_work;
+  CUSOLVER_SAFE_CALL(
+      cusolverDnSpotrf_bufferSize(GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER,
+                                  ivector_dim_, A.Data(), A.Stride(), &L_work));
+
+  // allocate the temp buffer (L_work is in elements, not bytes)
+  float *workspace = static_cast<float *>(
+      CuDevice::Instantiate().Malloc(L_work * sizeof(float)));
+
+  // perform the factorization
+  CUSOLVER_SAFE_CALL(cusolverDnSpotrf(
+      GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER, ivector_dim_, A.Data(),
+      A.Stride(), workspace, L_work, d_info));
+
+  // solve for the right-hand side
+  CUSOLVER_SAFE_CALL(cusolverDnSpotrs(
+      GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER, ivector_dim_, nrhs,
+      A.Data(), A.Stride(), ivector->Data(), ivector_dim_, d_info));
+
+  CuDevice::Instantiate().Free(workspace);
+
+  // remove the prior offset from the first dimension
+  CuSubVector<float> ivector0(*ivector, 0, 1);
+  ivector0.Add(-prior_offset_);
+}
+
+}  // namespace kaldi
diff --git a/src/cudafeat/online-ivector-feature-cuda.h b/src/cudafeat/online-ivector-feature-cuda.h
new file mode 100644
index 00000000000..b661521f782
--- /dev/null
+++ b/src/cudafeat/online-ivector-feature-cuda.h
@@ -0,0 +1,123 @@
+// cudafeat/online-ivector-feature-cuda.h
+//
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef CUDAFEAT_ONLINE_IVECTOR_FEATURE_CUDA_H_
+#define CUDAFEAT_ONLINE_IVECTOR_FEATURE_CUDA_H_
+
+#include <string>
+#include <vector>
+
+#include "base/kaldi-error.h"
+#include "cudafeat/feature-online-cmvn-cuda.h"
+#include "cudamatrix/cu-matrix.h"
+#include "online2/online-ivector-feature.h"
+
+namespace kaldi {
+
+class IvectorExtractorFastCuda {
+ public:
+  IvectorExtractorFastCuda(const OnlineIvectorExtractionConfig &config)
+      : b_(0), tot_post_(2) {
+    if (config.use_most_recent_ivector == false) {
+      KALDI_WARN << "IvectorExtractorFastCuda: Ignoring "
+                    "use_most_recent_ivector=false.";
+    }
+    if (config.greedy_ivector_extractor == false) {
+      KALDI_WARN << "IvectorExtractorFastCuda: Ignoring "
+                    "greedy_ivector_extractor=false.";
+    }
+
+    info_.Init(config);
+    naive_cmvn_state_ = OnlineCmvnState(info_.global_cmvn_stats);
+    Read(config);
+    cu_lda_.Resize(info_.lda_mat.NumRows(), info_.lda_mat.NumCols());
+    cu_lda_.CopyFromMat(info_.lda_mat);
+  }
+  ~IvectorExtractorFastCuda() {}
+
+  // This function goes directly from features to an i-vector, which makes
+  // the computation easier to port to GPU and lets it run more efficiently.
+  //
+  // It is roughly the replacement for the following in kaldi:
+  //
+  //   DiagGmm::LogLikelihoods(), VectorToPosteriorEntry()
+  //   IvectorExtractorUtteranceStats::AccStats()
+  //   IvectorExtractor::GetIvectorDistribution()
+  //
+  // Also note we only use single precision (float), which will *NOT* give
+  // the same results as the kaldi i-vector extractor, which uses double
+  // precision; in practice, the differences do *NOT* affect overall accuracy.
+  //
+  // This function is thread safe, as all class variables are read-only.
+  void GetIvector(const CuMatrixBase<float> &feats, CuVector<float> *ivector);
+
+  int32 FeatDim() const { return feat_dim_; }
+  int32 IvectorDim() const { return ivector_dim_; }
+  int32 NumGauss() const { return num_gauss_; }
+
+ private:
+  OnlineIvectorExtractionInfo info_;
+
+  IvectorExtractorFastCuda(IvectorExtractorFastCuda const &);
+  IvectorExtractorFastCuda &operator=(IvectorExtractorFastCuda const &);
+
+  void Read(const kaldi::OnlineIvectorExtractionConfig &config);
+
+  void SpliceFeats(const CuMatrixBase<float> &feats,
+                   CuMatrix<float> *spliced_feats);
+
+  void ComputePosteriors(const CuMatrixBase<float> &feats,
+                         CuMatrix<float> *posteriors);
+
+  void ComputeIvectorStats(const CuMatrixBase<float> &feats,
+                           const CuMatrixBase<float> &posteriors,
+                           CuVector<float> *gamma, CuMatrix<float> *X);
+
+  void ComputeIvectorFromStats(const CuVector<float> &gamma,
+                               const CuMatrix<float> &X,
+                               CuVector<float> *ivector);
+
+  CudaOnlineCmvnState naive_cmvn_state_;
+
+  int32 feat_dim_;
+  int32 ivector_dim_;
+  int32 num_gauss_;
+
+  // ubm variables
+  CuVector<float> ubm_gconsts_;
+  CuMatrix<float> ubm_means_inv_vars_;
+  CuMatrix<float> ubm_inv_vars_;
+  CuMatrix<float> cu_lda_;
+  // extractor variables
+  CuMatrix<float> ie_U_;
+
+  // Batched matrix which stores Sigma_inv[i] * M[i] for each Gaussian i.
+  CuMatrix<float> ie_Sigma_inv_M_f_;
+
+  // double buffer to store total posteriors.
+  // double buffering avoids extra calls to initialize the buffer
+  int b_;
+  CuVector<float> tot_post_;
+  float prior_offset_;
+};
+}  // namespace kaldi
+
+#endif  // CUDAFEAT_ONLINE_IVECTOR_FEATURE_CUDA_H_
diff --git a/src/cudafeatbin/Makefile b/src/cudafeatbin/Makefile
index e1af458b62e..5b392597d4a 100644
--- a/src/cudafeatbin/Makefile
+++ b/src/cudafeatbin/Makefile
@@ -9,7 +9,7 @@ LDLIBS += $(CUDA_LDLIBS)
 BINFILES =
 ifeq ($(CUDA), true)
-  BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda
+  BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda compute-online-feats-cuda
 endif
 
 OBJFILES =
diff --git a/src/cudafeatbin/compute-online-feats-cuda.cc b/src/cudafeatbin/compute-online-feats-cuda.cc
new file mode 100644
index 00000000000..b9135c3cee6
--- /dev/null
+++ b/src/cudafeatbin/compute-online-feats-cuda.cc
@@ -0,0 +1,123 @@
+// cudafeatbin/compute-online-feats-cuda.cc
+//
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if HAVE_CUDA == 1
+#include <nvToolsExt.h>
+#endif
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "cudafeat/online-cuda-feature-pipeline.h"
+#include "feat/wave-reader.h"
+#include "cudamatrix/cu-matrix.h"
+#include "cudamatrix/cu-vector.h"
+
+int main(int argc, char *argv[]) {
+  using namespace kaldi;
+  typedef kaldi::int32 int32;
+  typedef kaldi::int64 int64;
+  try {
+    const char *usage =
+        "Extract features and ivectors for utterances using the cuda online\n"
+        "feature pipeline. This binary models the online feature pipeline.\n"
+        "\n"
+        "Usage: compute-online-feats-cuda [options] <wave-rspecifier> "
+        "<ivector-wspecifier> <feats-wspecifier>\n"
+        "e.g.:\n"
+        "  ./compute-online-feats-cuda --config=feature_config wav.scp "
+        "ark,scp:ivector.ark,ivector.scp ark,scp:feat.ark,feat.scp\n";
+
+    ParseOptions po(usage);
+    // Use the online feature config, as that is the flow we are trying to
+    // model.
+    OnlineNnet2FeaturePipelineConfig feature_opts;
+
+    feature_opts.Register(&po);
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    g_cuda_allocator.SetOptions(g_allocator_options);
+    CuDevice::Instantiate().SelectGpuId("yes");
+    CuDevice::Instantiate().AllowMultithreading();
+
+    std::string wav_rspecifier = po.GetArg(1),
+        ivector_wspecifier = po.GetArg(2),
+        feature_wspecifier = po.GetArg(3);
+
+    OnlineCudaFeaturePipeline feature_pipeline(feature_opts);
+
+    SequentialTableReader<WaveHolder> reader(wav_rspecifier);
+    BaseFloatVectorWriter ivector_writer;
+    BaseFloatMatrixWriter feature_writer;
+
+    if (!ivector_writer.Open(ivector_wspecifier)) {
+      KALDI_ERR << "Could not initialize ivector_writer with wspecifier "
+                << ivector_wspecifier;
+    }
+    if (!feature_writer.Open(feature_wspecifier)) {
+      KALDI_ERR << "Could not initialize feature_writer with wspecifier "
+                << feature_wspecifier;
+    }
+
+    int32 num_utts = 0, num_success = 0;
+    for (; !reader.Done(); reader.Next()) {
+      num_utts++;
+      std::string utt = reader.Key();
+      KALDI_LOG << "Processing utterance " << utt;
+      try {
+        const WaveData &wave_data = reader.Value();
+        // use the first channel of the waveform
+        SubVector<BaseFloat> waveform(wave_data.Data(), 0);
+        CuVector<BaseFloat> cu_wave(waveform);
+        CuMatrix<BaseFloat> cu_features;
+        CuVector<BaseFloat> cu_ivector;
+
+        nvtxRangePushA("Feature Extract");
+        feature_pipeline.ComputeFeatures(cu_wave, wave_data.SampFreq(),
+                                         &cu_features, &cu_ivector);
+        cudaDeviceSynchronize();
+        nvtxRangePop();
+
+        Matrix<BaseFloat> features(cu_features.NumRows(),
+                                   cu_features.NumCols());
+        Vector<BaseFloat> ivector(cu_ivector.Dim());
+
+        features.CopyFromMat(cu_features);
+        ivector.CopyFromVec(cu_ivector);
+
+        feature_writer.Write(utt, features);
+        ivector_writer.Write(utt, ivector);
+
+        num_success++;
+      } catch (...) {
+        KALDI_WARN << "Failed to compute features for utterance " << utt;
+        continue;
+      }
+    }
+    KALDI_LOG << "Processed " << num_utts << " utterances with "
+              << num_utts - num_success << " failures.";
+    return (num_success != 0 ? 0 : 1);
+  } catch (const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}