diff --git a/src/nnet2/nnet-component.cc b/src/nnet2/nnet-component.cc index 6a18d309012..75d0fe3cde1 100644 --- a/src/nnet2/nnet-component.cc +++ b/src/nnet2/nnet-component.cc @@ -1407,6 +1407,20 @@ Component *AffineComponent::CollapseWithNext( return ans; } +Component *AffineComponent::CollapseWithNext( + const FixedScaleComponent &next_component) const { + KALDI_ASSERT(this->OutputDim() == next_component.InputDim()); + AffineComponent *ans = + dynamic_cast(this->Copy()); + KALDI_ASSERT(ans != NULL); + ans->linear_params_.MulRowsVec(next_component.scales_); + ans->bias_params_.MulElements(next_component.scales_); + + return ans; +} + + + Component *AffineComponent::CollapseWithPrevious( const FixedAffineComponent &prev_component) const { // If at least one was non-updatable, make the whole non-updatable. diff --git a/src/nnet2/nnet-component.h b/src/nnet2/nnet-component.h index a8519a0a6fb..44a19d28b2d 100644 --- a/src/nnet2/nnet-component.h +++ b/src/nnet2/nnet-component.h @@ -707,8 +707,9 @@ class ScaleComponent: public Component { -class SumGroupComponent; // Forward declaration. -class AffineComponent; // Forward declaration. +class SumGroupComponent; // Forward declaration. +class AffineComponent; // Forward declaration. +class FixedScaleComponent; // Forward declaration. class SoftmaxComponent: public NonlinearComponent { public: @@ -803,6 +804,7 @@ class AffineComponent: public UpdatableComponent { // FixedLinearComponent yet. Component *CollapseWithNext(const AffineComponent &next) const ; Component *CollapseWithNext(const FixedAffineComponent &next) const; + Component *CollapseWithNext(const FixedScaleComponent &next) const; Component *CollapseWithPrevious(const FixedAffineComponent &prev) const; virtual std::string Info() const; @@ -1473,6 +1475,7 @@ class FixedScaleComponent: public Component { virtual void Write(std::ostream &os, bool binary) const; protected: + friend class AffineComponent; // necessary for collapse CuVector scales_; KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent); }; diff --git a/src/nnet2/nnet-nnet.cc b/src/nnet2/nnet-nnet.cc index 31a5e5f08b6..fbf4dbb0678 100644 --- a/src/nnet2/nnet-nnet.cc +++ b/src/nnet2/nnet-nnet.cc @@ -51,7 +51,6 @@ int32 Nnet::LeftContext() const { // non-negative left context. In addition, the NnetExample also stores data // left context as positive integer. To be compatible with these other classes // Nnet::LeftContext() returns a non-negative left context. - } int32 Nnet::RightContext() const { @@ -66,8 +65,8 @@ int32 Nnet::RightContext() const { void Nnet::ComputeChunkInfo(int32 input_chunk_size, int32 num_chunks, std::vector *chunk_info_out) const { - // First compute the output-chunk indices for the last component in the network. - // we assume that the numbering of the input starts from zero. + // First compute the output-chunk indices for the last component in the + // network. we assume that the numbering of the input starts from zero. int32 output_chunk_size = input_chunk_size - LeftContext() - RightContext(); KALDI_ASSERT(output_chunk_size > 0); std::vector current_output_inds; @@ -88,7 +87,7 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size, for (int32 i = NumComponents() - 1; i >= 0; i--) { std::vector current_context = GetComponent(i).Context(); std::set current_input_ind_set; - for (size_t j = 0; j < current_context.size(); j++) + for (size_t j = 0; j < current_context.size(); j++) for (size_t k = 0; k < current_output_inds.size(); k++) current_input_ind_set.insert(current_context[j] + current_output_inds[k]); @@ -137,7 +136,6 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size, (*chunk_info_out)[i].Check(); // (*chunk_info_out)[i].ToString(); } - } const Component& Nnet::GetComponent(int32 component) const { @@ -359,7 +357,8 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { KALDI_ASSERT(new_num_pdfs > 0); KALDI_ASSERT(NumComponents() > 2); int32 nc = NumComponents(); - SumGroupComponent *sgc = dynamic_cast(components_[nc - 1]); + SumGroupComponent *sgc = + dynamic_cast(components_[nc - 1]); if (sgc != NULL) { // Remove it. We'll resize things later. delete sgc; @@ -367,21 +366,47 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { components_.begin() + nc); nc--; } - SoftmaxComponent *sc; if ((sc = dynamic_cast(components_[nc - 1])) == NULL) KALDI_ERR << "Expected last component to be SoftmaxComponent."; + // check if nc-1 has a FixedScaleComponent + bool has_fixed_scale_component = false; + int32 fixed_scale_component_index = -1; + int32 final_affine_component_index = nc - 2; + int32 softmax_component_index = nc - 1; + FixedScaleComponent *fsc = + dynamic_cast( + components_[final_affine_component_index]); + if (fsc != NULL) { + has_fixed_scale_component = true; + fixed_scale_component_index = nc - 2; + final_affine_component_index = nc - 3; + } // note: it could be child class of AffineComponent. - AffineComponent *ac = dynamic_cast(components_[nc - 2]); + AffineComponent *ac = dynamic_cast( + components_[final_affine_component_index]); if (ac == NULL) KALDI_ERR << "Network doesn't have expected structure (didn't find final " << "AffineComponent)."; - + if (has_fixed_scale_component) { + // collapse the fixed_scale_component with the affine_component before it + AffineComponent *ac_new = + dynamic_cast(ac->CollapseWithNext(*fsc)); + KALDI_ASSERT(ac_new != NULL); + delete fsc; + delete ac; + components_.erase(components_.begin() + fixed_scale_component_index, + components_.begin() + (fixed_scale_component_index + 1)); + components_[final_affine_component_index] = ac_new; + ac = ac_new; + softmax_component_index = softmax_component_index - 1; + } ac->Resize(ac->InputDim(), new_num_pdfs); // Remove the softmax component, and replace it with a new one - delete components_[nc - 1]; - components_[nc - 1] = new SoftmaxComponent(new_num_pdfs); + delete components_[softmax_component_index]; + components_[softmax_component_index] = new SoftmaxComponent(new_num_pdfs); + this->SetIndexes(); // used for debugging this->Check(); } @@ -655,8 +680,9 @@ void Nnet::Vectorize(VectorBase *params) const { KALDI_ASSERT(offset == GetParameterDim()); } -void Nnet::ResetGenerators() { // resets random-number generators for all random - // components. +void Nnet::ResetGenerators() { + // resets random-number generators for all random + // components. for (int32 c = 0; c < NumComponents(); c++) { RandomComponent *rc = dynamic_cast( &(GetComponent(c)));