diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc index d65193d9a54..d09c18b6ada 100644 --- a/src/nnet3/nnet-utils.cc +++ b/src/nnet3/nnet-utils.cc @@ -250,22 +250,6 @@ void ZeroComponentStats(Nnet *nnet) { } } -void ScaleLearningRate(BaseFloat learning_rate_scale, - Nnet *nnet) { - for (int32 c = 0; c < nnet->NumComponents(); c++) { - Component *comp = nnet->GetComponent(c); - if (comp->Properties() & kUpdatableComponent) { - // For now all updatable components inherit from class UpdatableComponent. - // If that changes in future, we will change this code. - UpdatableComponent *uc = dynamic_cast(comp); - if (uc == NULL) - KALDI_ERR << "Updatable component does not inherit from class " - "UpdatableComponent; change this code."; - uc->SetActualLearningRate(uc->LearningRate() * learning_rate_scale); - } - } -} - void SetLearningRate(BaseFloat learning_rate, Nnet *nnet) { for (int32 c = 0; c < nnet->NumComponents(); c++) { @@ -282,63 +266,6 @@ void SetLearningRate(BaseFloat learning_rate, } } -void SetLearningRates(const Vector &learning_rates, - Nnet *nnet) { - int32 i = 0; - for (int32 c = 0; c < nnet->NumComponents(); c++) { - Component *comp = nnet->GetComponent(c); - if (comp->Properties() & kUpdatableComponent) { - // For now all updatable components inherit from class UpdatableComponent. - // If that changes in future, we will change this code. - UpdatableComponent *uc = dynamic_cast(comp); - if (uc == NULL) - KALDI_ERR << "Updatable component does not inherit from class " - "UpdatableComponent; change this code."; - KALDI_ASSERT(i < learning_rates.Dim()); - uc->SetActualLearningRate(learning_rates(i++)); - } - } - KALDI_ASSERT(i == learning_rates.Dim()); -} - -void GetLearningRates(const Nnet &nnet, - Vector *learning_rates) { - learning_rates->Resize(NumUpdatableComponents(nnet)); - int32 i = 0; - for (int32 c = 0; c < nnet.NumComponents(); c++) { - const Component *comp = nnet.GetComponent(c); - if (comp->Properties() & kUpdatableComponent) { - // For now all updatable components inherit from class UpdatableComponent. - // If that changes in future, we will change this code. - const UpdatableComponent *uc = dynamic_cast(comp); - if (uc == NULL) - KALDI_ERR << "Updatable component does not inherit from class " - "UpdatableComponent; change this code."; - (*learning_rates)(i++) = uc->LearningRate(); - } - } - KALDI_ASSERT(i == learning_rates->Dim()); -} - -void ScaleNnetComponents(const Vector &scale_factors, - Nnet *nnet) { - int32 i = 0; - for (int32 c = 0; c < nnet->NumComponents(); c++) { - Component *comp = nnet->GetComponent(c); - if (comp->Properties() & kUpdatableComponent) { - // For now all updatable components inherit from class UpdatableComponent. - // If that changes in future, we will change this code. - UpdatableComponent *uc = dynamic_cast(comp); - if (uc == NULL) - KALDI_ERR << "Updatable component does not inherit from class " - "UpdatableComponent; change this code."; - KALDI_ASSERT(i < scale_factors.Dim()); - uc->Scale(scale_factors(i++)); - } - } - KALDI_ASSERT(i == scale_factors.Dim()); -} - void ScaleNnet(BaseFloat scale, Nnet *nnet) { if (scale == 1.0) return; else if (scale == 0.0) { diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h index 1e0dcefd703..9cbfa87a800 100644 --- a/src/nnet3/nnet-utils.h +++ b/src/nnet3/nnet-utils.h @@ -116,31 +116,9 @@ void ComputeSimpleNnetContext(const Nnet &nnet, void SetLearningRate(BaseFloat learning_rate, Nnet *nnet); -/// Scales the actual learning rate for all the components in the nnet -/// by this factor -void ScaleLearningRate(BaseFloat learning_rate_scale, - Nnet *nnet); - -/// Sets the actual learning rates for all the updatable components in the -/// neural net to the values in 'learning_rates' vector -/// (one for each updatable component). -void SetLearningRates(const Vector &learning_rates, - Nnet *nnet); - -/// Get the learning rates for all the updatable components in the neural net -/// (the output must have dim equal to the number of updatable components). -void GetLearningRates(const Nnet &nnet, - Vector *learning_rates); - /// Scales the nnet parameters and stats by this scale. void ScaleNnet(BaseFloat scale, Nnet *nnet); -/// Scales the parameters of each of the updatable components. -/// Here, scales is a vector of size equal to the number of updatable -/// components -void ScaleNnetComponents(const Vector &scales, - Nnet *nnet); - /// Does *dest += alpha * src (affects nnet parameters and /// stored stats). void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest); diff --git a/src/nnet3bin/nnet3-am-copy.cc b/src/nnet3bin/nnet3-am-copy.cc index 4851f839dcb..7aa0e4a32c0 100644 --- a/src/nnet3bin/nnet3-am-copy.cc +++ b/src/nnet3bin/nnet3-am-copy.cc @@ -80,9 +80,6 @@ int main(int argc, char *argv[]) { po.Register("learning-rate", &learning_rate, "If supplied, all the learning rates of updatable components" " are set to this value."); - po.Register("learning-rate-scale", &learning_rate_scale, - "Scales the learning rate of updatable components by this " - "factor"); po.Register("scale", &scale, "The parameter matrices are scaled" " by the specified value."); @@ -124,11 +121,6 @@ int main(int argc, char *argv[]) { if (learning_rate >= 0) SetLearningRate(learning_rate, &(am_nnet.GetNnet())); - KALDI_ASSERT(learning_rate_scale >= 0.0); - - if (learning_rate_scale != 1.0) - ScaleLearningRate(learning_rate_scale, &(am_nnet.GetNnet())); - if (!edits_config.empty()) { Input ki(edits_config); ReadEditConfig(ki.Stream(), &(am_nnet.GetNnet()));