diff --git a/src/nnet3/nnet-general-component.cc b/src/nnet3/nnet-general-component.cc index b1a2d9327f8..27955215c98 100644 --- a/src/nnet3/nnet-general-component.cc +++ b/src/nnet3/nnet-general-component.cc @@ -19,6 +19,7 @@ #include #include +#include #include "nnet3/nnet-general-component.h" #include "nnet3/nnet-computation-graph.h" #include "nnet3/nnet-parse.h" @@ -957,6 +958,7 @@ void BackpropTruncationComponentPrecomputedIndexes::Read(std::istream &istream, std::string BackpropTruncationComponent::Info() const { std::ostringstream stream; stream << Type() << ", dim=" << dim_ + << ", count=" << std::setprecision(3) << count_ << std::setprecision(6) << ", clipping-threshold=" << clipping_threshold_ << ", clipped-proportion=" << (count_ > 0.0 ? num_clipped_ / count_ : 0) diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index 58908a0fe09..53f87f43738 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -5110,6 +5110,13 @@ Component* LstmNonlinearityComponent::Copy() const { return new LstmNonlinearityComponent(*this); } +void LstmNonlinearityComponent::ZeroStats() { + value_sum_.SetZero(); + deriv_sum_.SetZero(); + self_repair_total_.SetZero(); + count_ = 0.0; +} + void LstmNonlinearityComponent::Scale(BaseFloat scale) { params_.Scale(scale); value_sum_.Scale(scale); diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index f09a989759a..47a4510526c 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -1744,6 +1744,7 @@ class LstmNonlinearityComponent: public UpdatableComponent { virtual int32 NumParameters() const; virtual void Vectorize(VectorBase *params) const; virtual void UnVectorize(const VectorBase ¶ms); + virtual void ZeroStats(); // Some functions that are specific to this class: explicit LstmNonlinearityComponent(