diff --git a/src/Makefile b/src/Makefile
index 766b6037b1b..96386c4650d 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -171,9 +171,11 @@ ivector: base util matrix transform tree gmm
 #3)Dependencies for optional parts of Kaldi
 onlinebin: base matrix util feat tree gmm transform sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online
 # python-kaldi-decoding: base matrix util feat tree gmm transform sgmm2 fstext hmm decoder lat online
-cudafeat: base matrix util gmm transform tree feat cudamatrix
+# cudafeat and cudafeatbin are built after online2; their Makefiles list
+# kaldi-online2.a in ADDLIBS.
+cudafeat: base matrix util gmm transform tree feat cudamatrix online2
+cudafeatbin: base matrix util gmm transform tree feat cudamatrix cudafeat online2
 online: decoder gmm transform feat matrix util base lat hmm tree
 online2: decoder gmm transform feat matrix util base lat hmm tree ivector cudamatrix nnet2 nnet3 chain
 kws: base util hmm tree matrix lat
-cudadecoder: cudamatrix online2 nnet3 ivector feat fstext lat chain transform
-cudadecoderbin: cudadecoder cudamatrix online2 nnet3 ivector feat fstext lat chain transform
+cudadecoder: cudamatrix cudafeat online2 nnet3 ivector feat fstext lat chain transform
+cudadecoderbin: cudadecoder cudafeat cudamatrix online2 nnet3 ivector feat fstext lat chain transform
diff --git a/src/cudafeat/Makefile b/src/cudafeat/Makefile
index b8ae247f87f..913c1ea9dbb 100644
--- a/src/cudafeat/Makefile
+++ b/src/cudafeat/Makefile
@@ -7,15 +7,17 @@ ifeq ($(CUDA), true)
 
 
 TESTFILES =
-
-OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o
-
+# NOTE(review): the hunk heading suggests this already sits inside an
+# ifeq ($(CUDA), true) block, which would make this guard redundant -- confirm.
+ifeq ($(CUDA), true)
+  OBJFILES += feature-window-cuda.o feature-mfcc-cuda.o feature-online-cmvn-cuda.o
+endif
 
 LIBNAME = kaldi-cudafeat
 
 ADDLIBS = ../feat/kaldi-feat.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
           ../base/kaldi-base.a ../cudamatrix/kaldi-cudamatrix.a \
-          ../gmm/kaldi-gmm.a
+          ../gmm/kaldi-gmm.a ../online2/kaldi-online2.a
 
 LDFLAGS += $(CUDA_LDFLAGS)
 LDLIBS += $(CUDA_LDLIBS)
diff --git a/src/cudafeat/feature-online-cmvn-cuda.cu b/src/cudafeat/feature-online-cmvn-cuda.cu
new file mode 100644
index 00000000000..22274b697b3
--- /dev/null
+++ b/src/cudafeat/feature-online-cmvn-cuda.cu
@@ -0,0 +1,247 @@
+// cudafeat/feature-online-cmvn-cuda.cu
+//
+// Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cub/cub.cuh>
+#include "cudafeat/feature-online-cmvn-cuda.h"
+#include "cudamatrix/cu-matrix.h"
+#include "cudamatrix/cu-vector.h"
+
+// Component-wise float2 arithmetic.  The kernels below carry each statistic
+// as a float2 with .x = sum of values and .y = sum of squared values.
+__device__ inline float2 operator-(const float2 &a, const float2 &b) {
+  float2 retval;
+  retval.x = a.x - b.x;
+  retval.y = a.y - b.y;
+  return retval;
+}
+__device__ inline float2 operator+(const float2 &a, const float2 &b) {
+  float2 retval;
+  retval.x = a.x + b.x;
+  retval.y = a.y + b.y;
+  return retval;
+}
+
+// Writes, for every feature column, the inclusive prefix sums over frames of
+// the values and of the squared values.
+//
+// Launch layout: one block per feature column (blockIdx.x is the column) and
+// 1024 threads per block, matching the block size BlockScan is specialized on.
+//
+//   data, ldd   input features; element (r, c) is data[r * ldd + c]
+//   num_frames  number of rows of data
+//   feat_dim    not read by the kernel; the grid size carries it
+//   stats, lds  output; row r, viewed as float2, holds at column c
+//               {sum of data[0..r][c], sum of data[0..r][c]^2}
+#if __CUDA_ARCH__ == 750
+__launch_bounds__ (1024, 1)
+#else
+__launch_bounds__ (1024, 2)
+#endif
+__global__ void compute_cmvn_stats_kernel(const float *data, int32_t ldd,
+                                          int32_t num_frames, int32_t feat_dim,
+                                          float *stats, int32_t lds) {
+  typedef cub::BlockScan<float2, 1024> BlockScan;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+
+  int32_t feat = blockIdx.x;
+
+  float2 running_sum = {0.0f, 0.0f};
+  // for each frame, keep threads alive for cub
+  for (int32_t r = 0; r < num_frames; r += blockDim.x) {
+    int32_t rid = r + threadIdx.x;
+
+    float val = 0.0f;
+
+    if (rid < num_frames) {
+      // uncoalesced, could transpose data or do some shared memory swizzling...
+      val = data[rid * ldd + feat];
+    }
+
+    float2 sum = {val, val * val};  // this element's value and value squared
+
+    float2 psum;   // row prefix sum
+    float2 total;  // sum over the whole block
+    BlockScan(temp_storage).InclusiveSum(sum, psum, total);
+
+    // offset by running sum
+    psum = psum + running_sum;
+    // increase running sum by new total
+    running_sum = running_sum + total;
+
+    // un-coalesced
+    if (rid < num_frames) {
+      reinterpret_cast<float2 *>(&stats[rid * lds])[feat] = psum;
+    }
+
+    // temp_storage is reused by the scan of the next iteration, so all
+    // threads must have finished with it before any of them moves on.
+    __syncthreads();
+  }
+}
+
+// Applies sliding-window mean/variance normalization to every element.
+//
+// Launch layout: one block per frame (blockIdx.x is the row); the threads of
+// a block stride over the columns.
+//
+//   cmvn_window   window length in frames
+//   var_norm      if false the variance is forced to 1 (no scaling)
+//   mean_norm     expected to be true; false hits the assert below
+//   feat_in, ldi  input features and their row stride
+//   num_rows      not read by the kernel; the grid size carries it
+//   num_cols      number of feature columns
+//   stats, lds    prefix sums laid out as compute_cmvn_stats_kernel writes them
+//   global_stats, ldg, global_frames
+//   speaker_stats, ldss, speaker_frames
+//                 prior stats used to pad a window that is not yet full:
+//                 row 0 holds the per-column sums with the frame count at
+//                 column num_cols, row 1 the sums of squares; *_frames caps
+//                 how many frames may be borrowed (0 disables the source)
+//   feat_out, ldo output features and their row stride
+__global__ void apply_cmvn_kernel(
+    int32_t cmvn_window, bool var_norm, bool mean_norm, const float *feat_in,
+    int32_t ldi, int32_t num_rows, int32_t num_cols,
+    const float *__restrict__ stats, int32_t lds,
+    const float *__restrict__ global_stats, int32_t ldg, int32_t global_frames,
+    const float *__restrict__ speaker_stats, int32_t ldss,
+    int32_t speaker_frames, float *feat_out, int32_t ldo) {
+  int32_t r = blockIdx.x;
+
+  for (int c = threadIdx.x; c < num_cols; c += blockDim.x) {
+    float2 frame_stats =
+        reinterpret_cast<const float2 *>(&stats[r * lds])[c];
+
+    float val = feat_in[r * ldi + c];
+
+    float window_length = min(r + 1, cmvn_window);
+
+    // we have to subtract row r-cmvn_window stats
+    if (r >= cmvn_window) {
+      // window starting row
+      int32_t o = r - cmvn_window;
+
+      // stats at the start row of the window that must be removed
+      float2 ostats =
+          reinterpret_cast<const float2 *>(&stats[o * lds])[c];
+
+      // remove start of the window stats
+      frame_stats = frame_stats - ostats;
+    }
+
+    // Smooth stats by speaker frames if necessary
+    float smooth_frames = cmvn_window - window_length;
+    if (smooth_frames > 0 && speaker_frames > 0) {
+      float count_from_speaker = min(smooth_frames, (float)speaker_frames);
+      float speaker_count = speaker_stats[num_cols];
+
+      if (count_from_speaker > 0.0f) {
+        float alpha = count_from_speaker / speaker_count;
+
+        frame_stats.x += alpha * speaker_stats[c];         // add to sum
+        frame_stats.y += alpha * speaker_stats[ldss + c];  // add to sum of squares
+        window_length += alpha * speaker_count;  // update window length
+
+        // recompute smooth frames now that we have speaker stats
+        smooth_frames = cmvn_window - window_length;
+      }
+    }
+
+    // Smooth stats by global frames if necessary
+    if (smooth_frames > 0 && global_frames > 0) {
+      float count_from_global = min(smooth_frames, (float)global_frames);
+      float global_count = global_stats[num_cols];
+
+      if (count_from_global > 0.0f) {
+        float alpha = count_from_global / global_count;
+
+        frame_stats.x += alpha * global_stats[c];        // add to sum
+        frame_stats.y += alpha * global_stats[ldg + c];  // add to sum of squares
+        window_length += alpha * global_count;  // update window length
+      }
+    }
+
+    float mean = frame_stats.x / window_length;
+    float var = frame_stats.y / window_length - mean * mean;
+
+    float floor = 1e-20f;
+    if (var < floor)  // avoid dividing by zero
+      var = floor;
+
+    if (!var_norm) {
+      // skip variance normalization
+      var = 1.0f;
+    }
+    if (!mean_norm) {
+      // NOTE(review): this assert makes mean_norm == false abort the kernel
+      // when asserts are compiled in -- confirm whether that is intended.
+      assert(false);
+      // skip mean normalization
+      mean = 0.0f;
+    }
+
+    // shift by mean and scale by variance
+    feat_out[r * ldo + c] = (val - mean) / sqrtf(var);
+  }
+}
+
+namespace kaldi {
+
+// Normalizes all frames of feats_in in one pass and writes the result to
+// feats_out, which is resized to the shape of feats_in.
+void CudaOnlineCmvn::ComputeFeatures(const CuMatrixBase<float> &feats_in,
+                                     CuMatrix<float> *feats_out) {
+  int32_t num_frames = feats_in.NumRows();
+  int32_t feat_dim = feats_in.NumCols();
+  feats_out->Resize(num_frames, feat_dim, kUndefined);
+
+  // Nothing to normalize, and a grid of zero blocks is not a valid launch.
+  if (num_frames == 0 || feat_dim == 0) return;
+
+  CuMatrix<float> stats(num_frames, feat_dim * 2, kUndefined);
+
+  int threads = 1024;  // must match the BlockScan size in the stats kernel
+  int blocks = feat_dim;
+
+  // compute windowed sum/sum2 prefix sum along column of feats
+  compute_cmvn_stats_kernel<<<blocks, threads>>>(
+      feats_in.Data(), feats_in.Stride(), num_frames, feat_dim, stats.Data(),
+      stats.Stride());
+  CU_SAFE_CALL(cudaGetLastError());
+
+  threads = (feat_dim + 31) / 32 * 32;  // round up to 32 threads
+  if (threads > 1024) threads = 1024;
+
+  const CuMatrix<float> &gstats = cmvn_state_.global_cmvn_stats;
+  const CuMatrix<float> &sstats = cmvn_state_.speaker_cmvn_stats;
+
+  int global_frames = opts_.global_frames;
+  int speaker_frames = opts_.speaker_frames;
+
+  // a source with no stats cannot contribute any frames
+  if (gstats.NumRows() == 0) global_frames = 0;
+  if (sstats.NumRows() == 0) speaker_frames = 0;
+
+  // apply cmvn
+  apply_cmvn_kernel<<<num_frames, threads>>>(
+      opts_.cmn_window, opts_.normalize_variance, opts_.normalize_mean,
+      feats_in.Data(), feats_in.Stride(), num_frames, feat_dim, stats.Data(),
+      stats.Stride(), gstats.Data(), gstats.Stride(), global_frames,
+      sstats.Data(), sstats.Stride(), speaker_frames, feats_out->Data(),
+      feats_out->Stride());
+  CU_SAFE_CALL(cudaGetLastError());
+}
+}  // namespace kaldi
diff --git a/src/cudafeat/feature-online-cmvn-cuda.h b/src/cudafeat/feature-online-cmvn-cuda.h
new file mode 100644
index 00000000000..729467a7a88
--- /dev/null
+++ b/src/cudafeat/feature-online-cmvn-cuda.h
@@ -0,0 +1,59 @@
+// cudafeat/feature-online-cmvn-cuda.h
+//
+// Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+// Justin Luitjens
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_CUDAFEAT_FEATURE_ONLINE_CMVN_CUDA_H_
+#define KALDI_CUDAFEAT_FEATURE_ONLINE_CMVN_CUDA_H_
+
+#include "cudamatrix/cu-matrix.h"
+#include "cudamatrix/cu-vector.h"
+#include "feat/online-feature.h"
+
+namespace kaldi {
+
+// CMVN statistics held as CuMatrix.  Each matrix is in the usual format,
+// of dimension 2 x (dim+1), as [ sum-stats          count
+//                                sum-squared-stats  0     ]
+struct CudaOnlineCmvnState {
+  CuMatrix<float> global_cmvn_stats;
+  CuMatrix<float> speaker_cmvn_stats;
+
+  CudaOnlineCmvnState() {}
+  // Copies both stats matrices from an OnlineCmvnState.
+  CudaOnlineCmvnState(const OnlineCmvnState &cmvn_state)
+      : global_cmvn_stats(cmvn_state.global_cmvn_stats),
+        speaker_cmvn_stats(cmvn_state.speaker_cmvn_stats) {}
+  CudaOnlineCmvnState(const CudaOnlineCmvnState &cmvn_state)
+      : global_cmvn_stats(cmvn_state.global_cmvn_stats),
+        speaker_cmvn_stats(cmvn_state.speaker_cmvn_stats) {}
+};
+
+// Runs online CMVN over a whole feature matrix (see ComputeFeatures).  Only
+// references to opts and cmvn_state are stored; both must outlive this object.
+class CudaOnlineCmvn {
+ public:
+  CudaOnlineCmvn(const OnlineCmvnOptions &opts, const CudaOnlineCmvnState &cmvn_state)
+      : opts_(opts), cmvn_state_(cmvn_state) {}
+  ~CudaOnlineCmvn() {}
+  void ComputeFeatures(const CuMatrixBase<float> &feats_in,
+                       CuMatrix<float> *feats_out);
+
+ private:
+  const OnlineCmvnOptions &opts_;
+  const CudaOnlineCmvnState &cmvn_state_;
+};
+}  // namespace kaldi
+#endif  // KALDI_CUDAFEAT_FEATURE_ONLINE_CMVN_CUDA_H_
diff --git a/src/cudafeatbin/Makefile b/src/cudafeatbin/Makefile
index 5e4aacad11b..e1af458b62e 100644
--- a/src/cudafeatbin/Makefile
+++ b/src/cudafeatbin/Makefile
@@ -9,7 +9,7 @@ LDLIBS += $(CUDA_LDLIBS)
 
 BINFILES =
 ifeq ($(CUDA), true)
-  BINFILES += compute-mfcc-feats-cuda
+  BINFILES += compute-mfcc-feats-cuda apply-cmvn-online-cuda
 endif
 
 OBJFILES =
@@ -20,6 +20,9 @@ ADDLIBS = ../cudafeat/kaldi-cudafeat.a ../cudamatrix/kaldi-cudamatrix.a \
           ../hmm/kaldi-hmm.a ../feat/kaldi-feat.a \
           ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \
           ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
-          ../base/kaldi-base.a
+          ../base/kaldi-base.a ../online2/kaldi-online2.a
+# NOTE(review): kaldi-online2.a is listed after libraries that src/Makefile
+# says online2 depends on (base, matrix, util, ...); with static libraries the
+# linker may need it earlier in the list -- confirm with a static build.
 
 include ../makefiles/default_rules.mk
diff --git a/src/cudafeatbin/apply-cmvn-online-cuda.cc
b/src/cudafeatbin/apply-cmvn-online-cuda.cc
new file mode 100644
index 00000000000..6dc18fdf2ab
--- /dev/null
+++ b/src/cudafeatbin/apply-cmvn-online-cuda.cc
@@ -0,0 +1,108 @@
+// cudafeatbin/apply-cmvn-online-cuda.cc
+
+// Copyright 2014  Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "feat/online-feature.h"
+#include "cudafeat/feature-online-cmvn-cuda.h"
+
+// Reads global CMVN stats and a table of feature matrices, normalizes each
+// matrix with CudaOnlineCmvn and writes the results out.  Exits with 0 when
+// at least one utterance was processed, 1 when none was or the argument count
+// is wrong, and -1 on an exception.
+int main(int argc, char *argv[]) {
+  try {
+    typedef kaldi::int32 int32;
+    using namespace kaldi;
+    const char *usage =
+        "Apply online cepstral mean (and possibly variance) computation online,\n"
+        "using the same code as used for online decoding in the 'new' setup in\n"
+        "online2/ and online2bin/.'\n"
+        "The computation is done on the device in serial. "
+        "spk2utt is not supported.\n"
+        "\n"
+        "Usage: apply-cmvn-online-cuda [options] <global-cmvn-stats> "
+        "<feature-rspecifier> <feature-wspecifier>\n"
+        "e.g. apply-cmvn-online-cuda 'matrix-sum scp:data/train/cmvn.scp -|' data/train/split8/1/feats.scp ark:-\n";
+
+    ParseOptions po(usage);
+
+    OnlineCmvnOptions cmvn_opts;
+    cmvn_opts.Register(&po);
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    g_cuda_allocator.SetOptions(g_allocator_options);
+    CuDevice::Instantiate().SelectGpuId("yes");
+    CuDevice::Instantiate().AllowMultithreading();
+
+    std::string global_stats_rxfilename = po.GetArg(1),
+        feature_rspecifier = po.GetArg(2),
+        feature_wspecifier = po.GetArg(3);
+
+    // global_cmvn_stats helps us initialize to online CMVN to
+    // reasonable values at the beginning of the utterance.
+    Matrix<double> global_cmvn_stats;
+    ReadKaldiObject(global_stats_rxfilename, &global_cmvn_stats);
+
+    BaseFloatMatrixWriter feature_writer(feature_wspecifier);
+    int32 num_done = 0;
+    int64 tot_t = 0;
+
+    OnlineCmvnState cmvn_state(global_cmvn_stats);
+    CudaOnlineCmvnState cu_cmvn_state(cmvn_state);
+    CudaOnlineCmvn cuda_cmvn(cmvn_opts, cu_cmvn_state);
+
+    SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
+    for (; !feature_reader.Done(); feature_reader.Next()) {
+      std::string utt = feature_reader.Key();
+      const Matrix<BaseFloat> &feats = feature_reader.Value();
+      int32_t numRows = feats.NumRows();
+      int32_t numCols = feats.NumCols();
+
+      CuMatrix<BaseFloat> cu_feats_in(feats);
+      CuMatrix<BaseFloat> cu_feats_out(numRows, numCols, kUndefined);
+      Matrix<BaseFloat> normalized_feats(numRows, numCols, kUndefined);
+
+      cuda_cmvn.ComputeFeatures(cu_feats_in, &cu_feats_out);
+
+      normalized_feats.CopyFromMat(cu_feats_out);
+
+      tot_t += feats.NumRows();
+      feature_writer.Write(utt, normalized_feats);
+
+      // count each utterance exactly once
+      num_done++;
+    }
+
+    KALDI_LOG << "Applied online CMVN to " << num_done << " files, or "
+              << tot_t << " frames.";
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}