From bf3257413242e3f3ff0e5b3ff066320a4495013e Mon Sep 17 00:00:00 2001 From: Hossein Hadian Date: Mon, 12 Mar 2018 13:09:52 -0400 Subject: [PATCH] end2end chain: Fix 2 issues Re bad chain.l2_term and numerator backward computation --- src/chain/chain-training.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/chain/chain-training.cc b/src/chain/chain-training.cc index 59fdfb4bbc0..1d357ace106 100644 --- a/src/chain/chain-training.cc +++ b/src/chain/chain-training.cc @@ -89,7 +89,9 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts, if (nnet_output_deriv) nnet_output_deriv->AddMat(1.0, *xent_output_deriv); } else if (nnet_output_deriv && numerator_ok) { - numerator.Backward(nnet_output_deriv); + numerator_ok = numerator.Backward(nnet_output_deriv); + if (!numerator_ok) + KALDI_LOG << "Numerator backward failed."; } } @@ -128,9 +130,8 @@ void ComputeChainObjfAndDerivE2e(const ChainTrainingOptions &opts, KALDI_LOG << "Derivs per frame are " << row_products_per_frame; } - if (opts.l2_regularize == 0.0) { - *l2_term = 0.0; - } else if (numerator_ok) { // we should have some derivs to include a L2 term + *l2_term = 0.0; + if (opts.l2_regularize != 0.0 && numerator_ok) { // we should have some derivs to include a L2 term // compute the l2 penalty term and its derivative BaseFloat scale = supervision.weight * opts.l2_regularize; *l2_term = -0.5 * scale * TraceMatMat(nnet_output, nnet_output, kTrans);