diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index d73f86d7542..444a3d1a6ce 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -251,9 +251,13 @@ void NonlinearComponent::ZeroStats() { std::string NonlinearComponent::Info() const { std::stringstream stream; - KALDI_ASSERT(InputDim() == OutputDim()); // always the case - stream << Type() << ", dim=" << InputDim(); - + if (InputDim() == OutputDim()) + stream << Type() << ", dim=" << InputDim(); + else + stream << Type() << ", input-dim=" << InputDim() + << ", output-dim=" << OutputDim() + << ", add-log-stddev=true"; + if (count_ > 0 && value_sum_.Dim() == dim_ && deriv_sum_.Dim() == dim_) { stream << ", count=" << std::setprecision(3) << count_ << std::setprecision(6); diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index 86f4739e30f..60fd245e6cf 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -240,26 +240,30 @@ void ElementwiseProductComponent::Write(std::ostream &os, bool binary) const { } const BaseFloat NormalizeComponent::kNormFloor = pow(2.0, -66); -// This component modifies the vector of activations by scaling it so that the -// root-mean-square equals 1.0. It's important that its square root -// be exactly representable in float. -void NormalizeComponent::Init(int32 dim, BaseFloat target_rms) { +// This component modifies the vector of activations by scaling it +// so that the root-mean-square equals 1.0. It's important that its +// square root be exactly representable in float. 
+void NormalizeComponent::Init(int32 dim, BaseFloat target_rms, + bool add_log_stddev) { KALDI_ASSERT(dim > 0); KALDI_ASSERT(target_rms > 0); dim_ = dim; count_ = 0.0; target_rms_ = target_rms; + add_log_stddev_ = add_log_stddev; } void NormalizeComponent::InitFromConfig(ConfigLine *cfl) { int32 dim = 0; + bool add_log_stddev = false; BaseFloat target_rms = 1.0; bool ok = cfl->GetValue("dim", &dim); cfl->GetValue("target-rms", &target_rms); + cfl->GetValue("add-log-stddev", &add_log_stddev); if (!ok || cfl->HasUnusedValues() || dim <= 0 || target_rms <= 0.0) KALDI_ERR << "Invalid initializer for layer of type " << Type() << ": \"" << cfl->WholeLine() << "\""; - Init(dim, target_rms); + Init(dim, target_rms, add_log_stddev); } void NormalizeComponent::Read(std::istream &is, bool binary) { std::ostringstream ostr_beg, ostr_end; @@ -275,6 +279,12 @@ void NormalizeComponent::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &target_rms_); ReadToken(is, binary, &tok); } + // Read the add_log_stddev_ value, if its token is present. + if (tok == "") { + ReadBasicType(is, binary, &add_log_stddev_); + ReadToken(is, binary, &tok); + } + // The new format is more readable as we write values that are normalized by // the count. KALDI_ASSERT(tok == ""); @@ -297,6 +307,8 @@ void NormalizeComponent::Write(std::ostream &os, bool binary) const { WriteBasicType(os, binary, dim_); WriteToken(os, binary, ""); WriteBasicType(os, binary, target_rms_); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, add_log_stddev_); // Write the values and derivatives in a count-normalized way, for // greater readability in text form. 
WriteToken(os, binary, ""); @@ -317,7 +329,11 @@ void NormalizeComponent::Write(std::ostream &os, bool binary) const { std::string NormalizeComponent::Info() const { std::ostringstream stream; stream << NonlinearComponent::Info(); - stream << ", target-rms=" << target_rms_; + stream << ", target-rms=" << target_rms_ + << ", add-log-stddev=" << add_log_stddev_; + if (add_log_stddev_) + stream << ", input-dim=" << InputDim() + << ", output-dim=" << OutputDim(); return stream.str(); } @@ -328,17 +344,31 @@ std::string NormalizeComponent::Info() const { // there is also flooring involved, to avoid division-by-zero // problems. It's important for the backprop, that the floor's // square root is exactly representable as float. +// If add_log_stddev_ is true, log(max(epsi, sqrt(x^t x / D))) +// is an extra dimension of the output. void NormalizeComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, CuMatrixBase *out) const { + KALDI_ASSERT(out->NumCols() == in.NumCols() + (add_log_stddev_ ? 1 : 0)); + CuSubMatrix out_no_log = out->ColRange(0, in.NumCols()); + out_no_log.CopyFromMat(in); CuVector in_norm(in.NumRows()); BaseFloat d_scaled = (in.NumCols() * target_rms_ * target_rms_); in_norm.AddDiagMat2(1.0 / d_scaled, in, kNoTrans, 0.0); + + if (add_log_stddev_) { + CuVector log_stddev(in.NumRows()); + // log_stddev is log(max(epsi, sqrt(row_in^T row_in / D))). + log_stddev.AddVec(target_rms_ * target_rms_, in_norm, 0.0); + log_stddev.ApplyPow(0.5); + log_stddev.ApplyFloor(kNormFloor); + log_stddev.ApplyLog(); + out->CopyColFromVec(log_stddev, in.NumCols()); + } in_norm.ApplyFloor(kNormFloor); in_norm.ApplyPow(-0.5); - out->CopyFromMat(in); - out->MulRowsVec(in_norm); + out_no_log.MulRowsVec(in_norm); } /* @@ -360,7 +390,9 @@ void NormalizeComponent::Propagate(const ComponentPrecomputedIndexes *indexes, dF/df df/dp dp/d(row_in) = 2/(D * target_rms^2) (f == 1.0 / sqrt(kNormFloor) ? 
0.0 : -0.5 f^3) (deriv_out^T row_in) row_in So deriv_in = f deriv_out + (f == 1.0 ? 0.0 : -f^3 / (D * target_rms^2) ) (deriv_out^T row_in) row_in - + + If add_log_stddev_ is true, deriv_in gets an additional term from the extra output + f = log(sqrt(x^T x / D)) (ignoring the floor): dF/dx_i = dF/df . df/dx_i, where df/dx_i = x_i/(x^T x) */ void NormalizeComponent::Backprop(const std::string &debug_info, const ComponentPrecomputedIndexes *indexes, @@ -370,19 +402,36 @@ void NormalizeComponent::Backprop(const std::string &debug_info, Component *to_update, CuMatrixBase *in_deriv) const { if (!in_deriv) return; + CuSubMatrix out_deriv_no_log = out_deriv.ColRange(0, + (out_deriv.NumCols() - (add_log_stddev_ ? 1 : 0))); CuVector dot_products(out_deriv.NumRows()); - dot_products.AddDiagMatMat(1.0, out_deriv, kNoTrans, in_value, kTrans, 0.0); + dot_products.AddDiagMatMat(1.0, out_deriv_no_log, kNoTrans, in_value, kTrans, 0.0); CuVector in_norm(in_value.NumRows()); - // dscaled == D * target_rms^2. BaseFloat d_scaled = (in_value.NumCols() * target_rms_ * target_rms_); - in_norm.AddDiagMat2(1.0 / d_scaled, + in_norm.AddDiagMat2(1.0, in_value, kNoTrans, 0.0); + + if (add_log_stddev_) { + CuVector log_stddev_deriv(in_norm), // log_stddev deriv as dF/dy .* (x^T x)^-1 + out_deriv_for_stddev(out_deriv.NumRows()); + // f = log(epsi < sqrt(x^T x / D) ? sqrt(x^T x / D) : epsi) + // => f = log( epsi^2 * D < x^T x ? 
sqrt(x^T x / D) : epsi) + BaseFloat new_knorm_floor = in_value.NumCols() * kNormFloor * kNormFloor; + log_stddev_deriv.ApplyFloor(new_knorm_floor); + log_stddev_deriv.ApplyPow(-1.0); + out_deriv_for_stddev.CopyColFromMat(out_deriv, (out_deriv.NumCols() - 1)); + log_stddev_deriv.MulElements(out_deriv_for_stddev); + if (in_deriv) + in_deriv->AddDiagVecMat(1.0, log_stddev_deriv, in_value, kNoTrans, 0.0); + } + + in_norm.Scale(1.0 / d_scaled); in_norm.ApplyFloor(kNormFloor); in_norm.ApplyPow(-0.5); if (in_deriv) { if (in_deriv->Data() != out_deriv.Data()) - in_deriv->AddDiagVecMat(1.0, in_norm, out_deriv, kNoTrans, 0.0); + in_deriv->AddDiagVecMat(1.0, in_norm, out_deriv_no_log, kNoTrans, (add_log_stddev_ ? 1.0 : 0.0)); else in_deriv->MulRowsVec(in_norm); } @@ -392,6 +441,7 @@ void NormalizeComponent::Backprop(const std::string &debug_info, in_deriv->AddDiagVecMat(-1.0 / d_scaled, dot_products, in_value, kNoTrans, 1.0); + } void SigmoidComponent::Propagate(const ComponentPrecomputedIndexes *indexes, diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index 388f2666740..acf79e3f7df 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -120,15 +120,17 @@ class NormalizeComponent: public NonlinearComponent { // note: although we inherit from NonlinearComponent, we don't actually bohter // accumulating the stats that NonlinearComponent is capable of accumulating. 
public: - void Init(int32 dim, BaseFloat target_rms); - explicit NormalizeComponent(int32 dim, BaseFloat target_rms = 1.0) { Init(dim, target_rms); } + void Init(int32 dim, BaseFloat target_rms, bool add_log_stddev); + explicit NormalizeComponent(int32 dim, BaseFloat target_rms = 1.0, + bool add_log_stddev = false) { Init(dim, target_rms, add_log_stddev); } explicit NormalizeComponent(const NormalizeComponent &other): NonlinearComponent(other), - target_rms_(other.target_rms_) { } + target_rms_(other.target_rms_), add_log_stddev_(other.add_log_stddev_) { } virtual int32 Properties() const { - return kSimpleComponent|kBackpropNeedsInput|kPropagateInPlace| - kBackpropInPlace; + return (add_log_stddev_ ? kSimpleComponent|kBackpropNeedsInput : + kSimpleComponent|kBackpropNeedsInput|kPropagateInPlace| + kBackpropInPlace); } - NormalizeComponent(): target_rms_(1.0) { } + NormalizeComponent(): target_rms_(1.0), add_log_stddev_(false) { } virtual std::string Type() const { return "NormalizeComponent"; } virtual void InitFromConfig(ConfigLine *cfl); virtual Component* Copy() const { return new NormalizeComponent(*this); } @@ -149,6 +151,8 @@ class NormalizeComponent: public NonlinearComponent { /// Write component to stream virtual void Write(std::ostream &os, bool binary) const; + virtual int32 OutputDim() const { return (dim_ + (add_log_stddev_ ? 1 : 0)); } + virtual std::string Info() const; private: NormalizeComponent &operator = (const NormalizeComponent &other); // Disallow. @@ -157,6 +161,9 @@ class NormalizeComponent: public NonlinearComponent { // about 0.7e-20. We need a value that's exactly representable in // float and whose inverse square root is also exactly representable // in float (hence, an even power of two). + + bool add_log_stddev_; // If true, log(max(epsi, sqrt(row_in^T row_in / D))) + // is an extra dimension of the output. 
}; diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index 8286b7d8782..7815b7a98ee 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -889,9 +889,11 @@ static void GenerateRandomComponentConfig(std::string *component_type, } case 1: { BaseFloat target_rms = (RandInt(1, 200) / 100.0); + std::string add_log_stddev = (Rand() % 2 == 0 ? "True" : "False"); *component_type = "NormalizeComponent"; os << "dim=" << RandInt(1, 50) - << " target-rms=" << target_rms; + << " target-rms=" << target_rms + << " add-log-stddev=" << add_log_stddev; break; } case 2: {