Skip to content
8 changes: 6 additions & 2 deletions src/nnet3/nnet-general-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ void StatisticsPoolingComponent::InitFromConfig(ConfigLine *cfl) {
StatisticsPoolingComponent::StatisticsPoolingComponent():
input_dim_(-1), input_period_(1), left_context_(-1), right_context_(-1),
num_log_count_features_(0), output_stddevs_(false),
variance_floor_(1.0e-10) { }
variance_floor_(1.0e-10), require_direct_input_(false) { }


StatisticsPoolingComponent::StatisticsPoolingComponent(
Expand All @@ -582,7 +582,8 @@ StatisticsPoolingComponent::StatisticsPoolingComponent(
left_context_(other.left_context_), right_context_(other.right_context_),
num_log_count_features_(other.num_log_count_features_),
output_stddevs_(other.output_stddevs_),
variance_floor_(1.0e-10) {
variance_floor_(other.variance_floor_),
require_direct_input_(other.require_direct_input_) {
Check();
}

Expand Down Expand Up @@ -614,6 +615,9 @@ void StatisticsPoolingComponent::Read(std::istream &is, bool binary) {
ExpectToken(is, binary, "<VarianceFloor>");
ReadBasicType(is, binary, &variance_floor_);
ExpectToken(is, binary, "</StatisticsPoolingComponent>");
require_direct_input_ = false; // This is not written to disk, it's only used
// temporarily, in memory (see
// nnet3-xvector-compute-batched.cc).
Check();
}

Expand Down
15 changes: 14 additions & 1 deletion src/nnet3/nnet-general-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ class StatisticsExtractionComponentPrecomputedIndexes:
or whatever, instead of just component-name, because its output is only defined at multiples
of its input-period.

The output of StatisticsPoolingComponent will only be defined if at least one input was defined.
The output of StatisticsPoolingComponent will only be defined if at least one
input was defined.
*/
class StatisticsPoolingComponent: public Component {
public:
Expand Down Expand Up @@ -396,6 +397,11 @@ class StatisticsPoolingComponent: public Component {
const std::vector<Index> &output_indexes,
bool need_backprop) const;

// Setter for require_direct_input_.  When that flag is true, a given 't'
// value is only available at the output of this component if the same 't'
// value is computable at its input, which is what allows the 'real'
// left/right context of a network containing this component to be computed.
// With the default value of false, the left/right context will always
// appear to be 0.
void SetRequireDirectInput(bool b) {
  require_direct_input_ = b;
}

private:
// Checks that the parameters are valid.
void Check() const;
Expand All @@ -411,6 +417,13 @@ class StatisticsPoolingComponent: public Component {
int32 num_log_count_features_;
bool output_stddevs_;
BaseFloat variance_floor_;
// If require_direct_input_ is set to true, in order for a particular 't'
// value to be available at the output of this component, it will require that
// 't' value to be computable at the input. This is used in computing the
// "real" left/right context of the network, but this member isn't currently
// written to disk and will default to false when read.
bool require_direct_input_;

};

class StatisticsPoolingComponentPrecomputedIndexes:
Expand Down
15 changes: 12 additions & 3 deletions src/nnet3/nnet-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ void SetNnetAsGradient(Nnet *nnet) {
}
}

void SetRequireDirectInput(bool b, Nnet *nnet) {
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
if (dynamic_cast<StatisticsPoolingComponent*>(comp) != NULL)
dynamic_cast<StatisticsPoolingComponent*>(comp)->SetRequireDirectInput(b);
}
}


void ScaleNnet(BaseFloat scale, Nnet *nnet) {
if (scale == 1.0) return;
else {
Expand Down Expand Up @@ -724,7 +733,7 @@ class SvdApplier {
<< " components to FixedAffineComponent.";
}

// This function finds the minimum index of
// This function finds the minimum index of
// the Descending order sorted [input_vector],
// over a range of indices from [lower] to [upper] index,
// for which the sum of elements up to the found min. index is greater
Expand All @@ -743,7 +752,7 @@ class SvdApplier {
}
return (i+1);
}

// Here we perform SVD based refactoring of an input Affine component.
// After applying SVD, we sort the singular values in descending order,
// and take the subset of values which contribute to energy_threshold times
Expand Down Expand Up @@ -777,7 +786,7 @@ class SvdApplier {
if (energy_threshold_ > 0) {
BaseFloat min_singular_sum = energy_threshold_ * s2_sum_orig;
bottleneck_dim_ = GetReducedDimension(s2, 0, s2.Dim()-1, min_singular_sum);
}
}
SubVector<BaseFloat> this_part(s2, 0, bottleneck_dim_);
BaseFloat s2_sum_reduced = this_part.Sum();
BaseFloat shrinkage_ratio =
Expand Down
8 changes: 8 additions & 0 deletions src/nnet3/nnet-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ void ScaleNnet(BaseFloat scale, Nnet *nnet);
/// learning_rate_ to 1 for each UpdatableComponent in nnet
void SetNnetAsGradient(Nnet *nnet);


/// Calls the corresponding function in any component of type
/// StatisticsPoolingComponent; used as a way to compute the 'real' left-right
/// context of networks including StatisticsPoolingComponent, which will give you
/// the minimum chunk size they can consume.
void SetRequireDirectInput(bool b, Nnet *nnet);


/// Does *dest += alpha * src (affects nnet parameters and
/// stored stats).
void AddNnet(const Nnet &src, BaseFloat alpha, Nnet *dest);
Expand Down
2 changes: 2 additions & 0 deletions src/nnet3bin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \
nnet3-discriminative-subset-egs nnet3-get-egs-simple \
nnet3-discriminative-compute-from-egs nnet3-latgen-faster-looped \
nnet3-egs-augment-image nnet3-xvector-get-egs nnet3-xvector-compute \
nnet3-xvector-compute-batched \
nnet3-latgen-grammar nnet3-compute-batch nnet3-latgen-faster-batch \
cuda-gpu-available cuda-compiled

Expand All @@ -36,4 +37,5 @@ ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \
../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
../base/kaldi-base.a


include ../makefiles/default_rules.mk
Loading