Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/nnet3/nnet-component-itf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,13 @@ void NonlinearComponent::ZeroStats() {

std::string NonlinearComponent::Info() const {
std::stringstream stream;
KALDI_ASSERT(InputDim() == OutputDim()); // always the case
stream << Type() << ", dim=" << InputDim();

if (InputDim() == OutputDim())
stream << Type() << ", dim=" << InputDim();
else
stream << Type() << ", input-dim=" << InputDim()
<< ", output-dim=" << OutputDim()
<< ", add-log-stddev=true";

if (count_ > 0 && value_sum_.Dim() == dim_ && deriv_sum_.Dim() == dim_) {
stream << ", count=" << std::setprecision(3) << count_
<< std::setprecision(6);
Expand Down
76 changes: 63 additions & 13 deletions src/nnet3/nnet-simple-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,26 +240,30 @@ void ElementwiseProductComponent::Write(std::ostream &os, bool binary) const {
}

// Floor applied to the scaled squared row norm before the inverse square
// root is taken in Propagate/Backprop, to avoid division-by-zero problems.
// The value is 2^-66.
const BaseFloat NormalizeComponent::kNormFloor = pow(2.0, -66);
// This component modifies the vector of activations by scaling it so that the
// root-mean-square equals 1.0. It's important that its square root
// be exactly representable in float.
void NormalizeComponent::Init(int32 dim, BaseFloat target_rms) {
// This component modifies the vector of activations by scaling it
// so that the root-mean-square equals 1.0. It's important that its
// square root be exactly representable in float.
void NormalizeComponent::Init(int32 dim, BaseFloat target_rms,
bool add_log_stddev) {
KALDI_ASSERT(dim > 0);
KALDI_ASSERT(target_rms > 0);
dim_ = dim;
count_ = 0.0;
target_rms_ = target_rms;
add_log_stddev_ = add_log_stddev;
}

// Initializes the component from a config line.  Keys read here:
//   dim             required; must be > 0.
//   target-rms      optional, defaults to 1.0; must be > 0.
//   add-log-stddev  optional, defaults to false.
// An error is reported through KALDI_ERR if "dim" is missing, if
// cfl->HasUnusedValues() is true after the reads, or if dim / target-rms
// are not positive.
void NormalizeComponent::InitFromConfig(ConfigLine *cfl) {
  int32 dim = 0;
  bool add_log_stddev = false;
  BaseFloat target_rms = 1.0;
  bool ok = cfl->GetValue("dim", &dim);
  // The two optional keys keep their defaults when absent, so the return
  // values are deliberately not checked.
  cfl->GetValue("target-rms", &target_rms);
  cfl->GetValue("add-log-stddev", &add_log_stddev);
  if (!ok || cfl->HasUnusedValues() || dim <= 0 || target_rms <= 0.0)
    KALDI_ERR << "Invalid initializer for layer of type "
              << Type() << ": \"" << cfl->WholeLine() << "\"";
  // Single initialization with all three parameters.
  Init(dim, target_rms, add_log_stddev);
}
void NormalizeComponent::Read(std::istream &is, bool binary) {
std::ostringstream ostr_beg, ostr_end;
Expand All @@ -275,6 +279,12 @@ void NormalizeComponent::Read(std::istream &is, bool binary) {
ReadBasicType(is, binary, &target_rms_);
ReadToken(is, binary, &tok);
}
// Read the add_log_stddev_ value, if its token is present.
if (tok == "<AddLogStddev>") {
ReadBasicType(is, binary, &add_log_stddev_);
ReadToken(is, binary, &tok);
}

// The new format is more readable as we write values that are normalized by
// the count.
KALDI_ASSERT(tok == "<ValueAvg>");
Expand All @@ -297,6 +307,8 @@ void NormalizeComponent::Write(std::ostream &os, bool binary) const {
WriteBasicType(os, binary, dim_);
WriteToken(os, binary, "<TargetRms>");
WriteBasicType(os, binary, target_rms_);
WriteToken(os, binary, "<AddLogStddev>");
WriteBasicType(os, binary, add_log_stddev_);
// Write the values and derivatives in a count-normalized way, for
// greater readability in text form.
WriteToken(os, binary, "<ValueAvg>");
Expand All @@ -317,7 +329,11 @@ void NormalizeComponent::Write(std::ostream &os, bool binary) const {
// Returns a one-line description of this component: the text produced by
// NonlinearComponent::Info(), followed by target-rms and add-log-stddev,
// and (only when add_log_stddev_ is true) the input and output dimensions.
std::string NormalizeComponent::Info() const {
  std::ostringstream stream;
  // target-rms is written exactly once.
  stream << NonlinearComponent::Info()
         << ", target-rms=" << target_rms_
         << ", add-log-stddev=" << add_log_stddev_;
  if (add_log_stddev_)
    stream << ", input-dim=" << InputDim()
           << ", output-dim=" << OutputDim();
  return stream.str();
}

Expand All @@ -328,17 +344,31 @@ std::string NormalizeComponent::Info() const {
// there is also flooring involved, to avoid division-by-zero
// problems. It's important for the backprop, that the floor's
// square root is exactly representable as float.
// If add_log_stddev_ is true, log(max(epsi, sqrt(x^t x / D)))
// is an extra dimension of the output.
void NormalizeComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
                                   const CuMatrixBase<BaseFloat> &in,
                                   CuMatrixBase<BaseFloat> *out) const {
  // Each row of 'in' is multiplied by
  //   1 / sqrt(max(kNormFloor, row^T row / (D * target_rms^2))),
  // where D == in.NumCols(), and written to the first D columns of *out.
  // If add_log_stddev_ is true, *out has one extra (last) column that holds
  // log(max(kNormFloor, sqrt(row^T row / D))).
  KALDI_ASSERT(out->NumCols() == in.NumCols() + (add_log_stddev_ ? 1 : 0));

  // The part of the output that receives the normalized activations.
  CuSubMatrix<BaseFloat> out_no_log = out->ColRange(0, in.NumCols());
  // Properties() allows in-place propagation when add_log_stddev_ is false;
  // in that case the data is already where it needs to be.
  if (in.Data() != out_no_log.Data())
    out_no_log.CopyFromMat(in);

  // in_norm(r) = row_r^T row_r / (D * target_rms^2).
  CuVector<BaseFloat> in_norm(in.NumRows());
  BaseFloat d_scaled = (in.NumCols() * target_rms_ * target_rms_);
  in_norm.AddDiagMat2(1.0 / d_scaled,
                      in, kNoTrans, 0.0);

  if (add_log_stddev_) {
    CuVector<BaseFloat> log_stddev(in.NumRows());
    // log_stddev is log(max(epsi, sqrt(row_in^T row_in / D))).
    // Multiplying by target_rms^2 undoes the target-rms scaling of in_norm.
    log_stddev.AddVec(target_rms_ * target_rms_, in_norm, 0.0);
    log_stddev.ApplyPow(0.5);
    log_stddev.ApplyFloor(kNormFloor);
    log_stddev.ApplyLog();
    out->CopyColFromVec(log_stddev, in.NumCols());
  }

  // Turn in_norm into the per-row scaling factor and apply it once, to the
  // activation columns only (the log-stddev column must not be rescaled).
  in_norm.ApplyFloor(kNormFloor);
  in_norm.ApplyPow(-0.5);
  out_no_log.MulRowsVec(in_norm);
}

/*
Expand All @@ -360,7 +390,9 @@ void NormalizeComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
dF/df df/dp dp/d(row_in) = 2/(D * target_rms^2) (f == 1.0 / sqrt(kNormFloor) ? 0.0 : -0.5 f^3) (deriv_out^T row_in) row_in
So
deriv_in = f deriv_out + (f == 1.0 ? 0.0 : -f^3 / (D * target_rms^2) ) (deriv_out^T row_in) row_in


if add_log_stddev_ true, the deriv_in has another term as
dF/dx_i = dF/df . df/dx_i => df/dx_i = x_i/(x^T x)
*/
void NormalizeComponent::Backprop(const std::string &debug_info,
const ComponentPrecomputedIndexes *indexes,
Expand All @@ -370,19 +402,36 @@ void NormalizeComponent::Backprop(const std::string &debug_info,
Component *to_update,
CuMatrixBase<BaseFloat> *in_deriv) const {
if (!in_deriv) return;
CuSubMatrix<BaseFloat> out_deriv_no_log = out_deriv.ColRange(0,
(out_deriv.NumCols() - (add_log_stddev_ ? 1 : 0)));
CuVector<BaseFloat> dot_products(out_deriv.NumRows());
dot_products.AddDiagMatMat(1.0, out_deriv, kNoTrans, in_value, kTrans, 0.0);
dot_products.AddDiagMatMat(1.0, out_deriv_no_log, kNoTrans, in_value, kTrans, 0.0);
CuVector<BaseFloat> in_norm(in_value.NumRows());
// dscaled == D * target_rms^2.
BaseFloat d_scaled = (in_value.NumCols() * target_rms_ * target_rms_);
in_norm.AddDiagMat2(1.0 / d_scaled,
in_norm.AddDiagMat2(1.0,
in_value, kNoTrans, 0.0);

if (add_log_stddev_) {
CuVector<BaseFloat> log_stddev_deriv(in_norm), // log_stddev deriv as dF/dy .* (x^T x)^-1
out_deriv_for_stddev(out_deriv.NumRows());
// f = log((epsi < sqrt(x^T x / D) ? sqrt(x^T x / D) : epsi)
// => f = log( epsi^2 * D < x^T x ? sqrt(x^T x / D) : epsi)
BaseFloat new_knorm_floor = in_value.NumCols() * kNormFloor * kNormFloor;
log_stddev_deriv.ApplyFloor(new_knorm_floor);
log_stddev_deriv.ApplyPow(-1.0);
out_deriv_for_stddev.CopyColFromMat(out_deriv, (out_deriv.NumCols() - 1));
log_stddev_deriv.MulElements(out_deriv_for_stddev);
if (in_deriv)
in_deriv->AddDiagVecMat(1.0, log_stddev_deriv, in_value, kNoTrans, 0.0);
}

in_norm.Scale(1.0 / d_scaled);
in_norm.ApplyFloor(kNormFloor);
in_norm.ApplyPow(-0.5);

if (in_deriv) {
if (in_deriv->Data() != out_deriv.Data())
in_deriv->AddDiagVecMat(1.0, in_norm, out_deriv, kNoTrans, 0.0);
in_deriv->AddDiagVecMat(1.0, in_norm, out_deriv_no_log, kNoTrans, (add_log_stddev_ ? 1.0 : 0.0));
else
in_deriv->MulRowsVec(in_norm);
}
Expand All @@ -392,6 +441,7 @@ void NormalizeComponent::Backprop(const std::string &debug_info,
in_deriv->AddDiagVecMat(-1.0 / d_scaled,
dot_products, in_value,
kNoTrans, 1.0);

}

void SigmoidComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
Expand Down
19 changes: 13 additions & 6 deletions src/nnet3/nnet-simple-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,17 @@ class NormalizeComponent: public NonlinearComponent {
// note: although we inherit from NonlinearComponent, we don't actually bother
// accumulating the stats that NonlinearComponent is capable of accumulating.
public:
void Init(int32 dim, BaseFloat target_rms);
explicit NormalizeComponent(int32 dim, BaseFloat target_rms = 1.0) { Init(dim, target_rms); }
void Init(int32 dim, BaseFloat target_rms, bool add_log_stddev);
explicit NormalizeComponent(int32 dim, BaseFloat target_rms = 1.0,
bool add_log_stddev = false) { Init(dim, target_rms, add_log_stddev); }
explicit NormalizeComponent(const NormalizeComponent &other): NonlinearComponent(other),
target_rms_(other.target_rms_) { }
target_rms_(other.target_rms_), add_log_stddev_(other.add_log_stddev_) { }
virtual int32 Properties() const {
return kSimpleComponent|kBackpropNeedsInput|kPropagateInPlace|
kBackpropInPlace;
return (add_log_stddev_ ? kSimpleComponent|kBackpropNeedsInput :
kSimpleComponent|kBackpropNeedsInput|kPropagateInPlace|
kBackpropInPlace);
}
NormalizeComponent(): target_rms_(1.0) { }
NormalizeComponent(): target_rms_(1.0), add_log_stddev_(false) { }
virtual std::string Type() const { return "NormalizeComponent"; }
virtual void InitFromConfig(ConfigLine *cfl);
virtual Component* Copy() const { return new NormalizeComponent(*this); }
Expand All @@ -149,6 +151,8 @@ class NormalizeComponent: public NonlinearComponent {
/// Write component to stream
virtual void Write(std::ostream &os, bool binary) const;

virtual int32 OutputDim() const { return (dim_ + (add_log_stddev_ ? 1 : 0)); }

virtual std::string Info() const;
private:
NormalizeComponent &operator = (const NormalizeComponent &other); // Disallow.
Expand All @@ -157,6 +161,9 @@ class NormalizeComponent: public NonlinearComponent {
// about 0.7e-20. We need a value that's exactly representable in
// float and whose inverse square root is also exactly representable
// in float (hence, an even power of two).

bool add_log_stddev_; // If true, log(max(epsi, sqrt(row_in^T row_in / D)))
// is an extra dimension of the output.
};


Expand Down
4 changes: 3 additions & 1 deletion src/nnet3/nnet-test-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,11 @@ static void GenerateRandomComponentConfig(std::string *component_type,
}
case 1: {
BaseFloat target_rms = (RandInt(1, 200) / 100.0);
std::string add_log_stddev = (Rand() % 2 == 0 ? "True" : "False");
*component_type = "NormalizeComponent";
os << "dim=" << RandInt(1, 50)
<< " target-rms=" << target_rms;
<< " target-rms=" << target_rms
<< " add-log-stddev=" << add_log_stddev;
break;
}
case 2: {
Expand Down