Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 0 additions & 73 deletions src/nnet3/nnet-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,6 @@ void ZeroComponentStats(Nnet *nnet) {
}
}

void ScaleLearningRate(BaseFloat learning_rate_scale,
Nnet *nnet) {
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
if (comp->Properties() & kUpdatableComponent) {
// For now all updatable components inherit from class UpdatableComponent.
// If that changes in future, we will change this code.
UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
if (uc == NULL)
KALDI_ERR << "Updatable component does not inherit from class "
"UpdatableComponent; change this code.";
uc->SetActualLearningRate(uc->LearningRate() * learning_rate_scale);
}
}
}

void SetLearningRate(BaseFloat learning_rate,
Nnet *nnet) {
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Expand All @@ -282,63 +266,6 @@ void SetLearningRate(BaseFloat learning_rate,
}
}

void SetLearningRates(const Vector<BaseFloat> &learning_rates,
Nnet *nnet) {
int32 i = 0;
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
if (comp->Properties() & kUpdatableComponent) {
// For now all updatable components inherit from class UpdatableComponent.
// If that changes in future, we will change this code.
UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
if (uc == NULL)
KALDI_ERR << "Updatable component does not inherit from class "
"UpdatableComponent; change this code.";
KALDI_ASSERT(i < learning_rates.Dim());
uc->SetActualLearningRate(learning_rates(i++));
}
}
KALDI_ASSERT(i == learning_rates.Dim());
}

void GetLearningRates(const Nnet &nnet,
Vector<BaseFloat> *learning_rates) {
learning_rates->Resize(NumUpdatableComponents(nnet));
int32 i = 0;
for (int32 c = 0; c < nnet.NumComponents(); c++) {
const Component *comp = nnet.GetComponent(c);
if (comp->Properties() & kUpdatableComponent) {
// For now all updatable components inherit from class UpdatableComponent.
// If that changes in future, we will change this code.
const UpdatableComponent *uc = dynamic_cast<const UpdatableComponent*>(comp);
if (uc == NULL)
KALDI_ERR << "Updatable component does not inherit from class "
"UpdatableComponent; change this code.";
(*learning_rates)(i++) = uc->LearningRate();
}
}
KALDI_ASSERT(i == learning_rates->Dim());
}

void ScaleNnetComponents(const Vector<BaseFloat> &scale_factors,
Nnet *nnet) {
int32 i = 0;
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
if (comp->Properties() & kUpdatableComponent) {
// For now all updatable components inherit from class UpdatableComponent.
// If that changes in future, we will change this code.
UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(comp);
if (uc == NULL)
KALDI_ERR << "Updatable component does not inherit from class "
"UpdatableComponent; change this code.";
KALDI_ASSERT(i < scale_factors.Dim());
uc->Scale(scale_factors(i++));
}
}
KALDI_ASSERT(i == scale_factors.Dim());
}

void ScaleNnet(BaseFloat scale, Nnet *nnet) {
if (scale == 1.0) return;
else if (scale == 0.0) {
Expand Down
22 changes: 0 additions & 22 deletions src/nnet3/nnet-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,9 @@ void ComputeSimpleNnetContext(const Nnet &nnet,
void SetLearningRate(BaseFloat learning_rate,
Nnet *nnet);

/// Scales the actual learning rate for all the components in the nnet
/// by this factor
void ScaleLearningRate(BaseFloat learning_rate_scale,
Nnet *nnet);

/// Sets the actual learning rates for all the updatable components in the
/// neural net to the values in 'learning_rates' vector
/// (one for each updatable component).
void SetLearningRates(const Vector<BaseFloat> &learning_rates,
Nnet *nnet);

/// Get the learning rates for all the updatable components in the neural net
/// (the output must have dim equal to the number of updatable components).
void GetLearningRates(const Nnet &nnet,
Vector<BaseFloat> *learning_rates);

/// Scales the nnet parameters and stats by this scale.
void ScaleNnet(BaseFloat scale, Nnet *nnet);

/// Scales the parameters of each of the updatable components.
/// Here, scales is a vector of size equal to the number of updatable
/// components
void ScaleNnetComponents(const Vector<BaseFloat> &scales,
Nnet *nnet);

/// Does *dest += alpha * src (affects nnet parameters and
/// stored stats).
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);
Expand Down
8 changes: 0 additions & 8 deletions src/nnet3bin/nnet3-am-copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ int main(int argc, char *argv[]) {
po.Register("learning-rate", &learning_rate,
"If supplied, all the learning rates of updatable components"
" are set to this value.");
po.Register("learning-rate-scale", &learning_rate_scale,
"Scales the learning rate of updatable components by this "
"factor");
po.Register("scale", &scale, "The parameter matrices are scaled"
" by the specified value.");

Expand Down Expand Up @@ -124,11 +121,6 @@ int main(int argc, char *argv[]) {
if (learning_rate >= 0)
SetLearningRate(learning_rate, &(am_nnet.GetNnet()));

KALDI_ASSERT(learning_rate_scale >= 0.0);

if (learning_rate_scale != 1.0)
ScaleLearningRate(learning_rate_scale, &(am_nnet.GetNnet()));

if (!edits_config.empty()) {
Input ki(edits_config);
ReadEditConfig(ki.Stream(), &(am_nnet.GetNnet()));
Expand Down