diff --git a/src/nnet3/nnet-computation-graph.cc b/src/nnet3/nnet-computation-graph.cc index 9c84115d406..bded9e84b2f 100644 --- a/src/nnet3/nnet-computation-graph.cc +++ b/src/nnet3/nnet-computation-graph.cc @@ -1643,7 +1643,7 @@ int32 ComputationStepsComputer::AddStep(const std::vector &cindexes, *out_iter = cindex_id; if (added) { KALDI_ASSERT(cindex_id == static_cast(locations_->size())); - locations_->resize(cindex_id + 1, std::pair(-1, -1)); + locations_->resize(cindex_id + 1); locations_->back().first = step_index; locations_->back().second = row_index; locations = &((*locations_)[0]); // in case it was reallocated @@ -1867,52 +1867,68 @@ void ComputationStepsComputer::ProcessDimRangeSubPhase( ConvertToCindexIds(input_cindexes, &input_cindex_ids); std::vector > locations; ConvertToLocations(input_cindex_ids, &locations); - - // get a list of the source step indexes (corresponding to computations for the - // source component-node) - std::unordered_set source_step_indexes; + std::sort(locations.begin(), locations.end()); KALDI_ASSERT(!locations.empty()); std::vector >::const_iterator locations_iter = locations.begin(), locations_end = locations.end(); - - // 'cur_source_step_index' is just an optimization to prevent unnecessary - // unordered_set inserts. - int32 cur_source_step_index = -1; - for (; locations_iter != locations_end; ++locations_iter) { - int32 source_step_index = locations_iter->first; - if (source_step_index != cur_source_step_index) { - cur_source_step_index = source_step_index; - source_step_indexes.insert(cur_source_step_index); + // Each unique .first number in locations (i.e. each source step, and they + // will all correspond to component-output or input steps) will generate one + // 'step' of type kDimRange. Because dim-range nodes must be contiguous + // ranges of a source step (since they are represented as sub-matrices), for + // each source step we work out the first and last row-index (i.e. first and + // last .second member of locations) and use that to reconstruct the range. + + // each element of 'steps' will be (source_step, (begin_row, end_row)) so that + // the source of the dim-range node is indexes begin_row ... end_row-1 in that + // source step. + std::vector > > steps; + + int32 cur_source_step = locations_iter->first, + cur_row_begin = locations_iter->second, + cur_row_end = cur_row_begin + 1; + while (1) { + ++locations_iter; + if (locations_iter == locations_end || + locations_iter->first != cur_source_step) { + // we reached the end of a run of the same step. + std::pair > this_step; + this_step.first = cur_source_step; + this_step.second.first = cur_row_begin; + this_step.second.second = cur_row_end; + steps.push_back(this_step); + if (locations_iter != locations_end) { + cur_source_step = locations_iter->first; + cur_row_begin = locations_iter->second; + cur_row_end = cur_row_begin + 1; + } else { + break; + } + } else { + cur_row_end = locations_iter->second + 1; } } - std::unordered_set::const_iterator - source_step_iter = source_step_indexes.begin(), - source_step_end = source_step_indexes.end(); - // iterating over the indexes of the source steps. - for (; source_step_iter != source_step_end; ++source_step_iter) { - int32 source_step_index = *source_step_iter; - std::pair p(source_step_index, dim_range_node); - if (dim_range_nodes_.count(p) > 0) { - // We don't need to do anything; a dim-range node already exists for this - // step and this node index. - continue; - } - dim_range_nodes_.insert(p); - const std::vector &source_step = (*steps_)[source_step_index]; - // 'cindexes' will be the cindexes of the new step that we're going to add. + for (size_t i = 0; i < steps.size(); i++) { + // iterating over different source steps, although normally + // there will be just one. + int32 source_step = steps[i].first, + row_begin = steps[i].second.first, + row_end = steps[i].second.second; + // 'source' is just the elements of the source step that we're consuming. + std::vector source((*steps_)[source_step].begin() + row_begin, + (*steps_)[source_step].begin() + row_end); std::vector cindexes; - ConvertToCindexes(source_step, &cindexes); + ConvertToCindexes(source, &cindexes); std::vector::iterator iter = cindexes.begin(), end = cindexes.end(); for (; iter != end; ++iter) iter->first = dim_range_node; bool add_if_absent = true; // this add_if_absent says, even if cindexes were not in the graph, - // add them. This is possible; the step will contain all cindexes for the - // input step, even if they won't be needed. (This is costless; it's just - // setting up a sub-matrix). + // add them. This is possible in principle; it's to satisfy the + // requirement that DimRangeNodes be implemented as contiguous ranges + // of rows of component nodes or input nodes. AddStep(cindexes, add_if_absent); } } diff --git a/src/nnet3/nnet-computation-graph.h b/src/nnet3/nnet-computation-graph.h index c0662756502..7999f2208ad 100644 --- a/src/nnet3/nnet-computation-graph.h +++ b/src/nnet3/nnet-computation-graph.h @@ -439,7 +439,7 @@ class ComputationStepsComputer { /// (step-index, index-into-step), so that for any cindex_id c, /// (*steps)[locations[c].first][locations[c].second] == c. /// It's possible in principle if there are non-simple - /// Components, that for nodes corresponding to component-input + /// Components, that for node corresponding to component-input /// descriptors, a cindex might be present in more than one step, /// so it doesn't follow that if (*steps)[i][j] == c, then /// locations[c] == (i,j). @@ -547,15 +547,6 @@ class ComputationStepsComputer { /// (*steps_)[i][j] == c. This is also an output (we get the pointer in /// the constructor). std::vector > *locations_; - - - /// dim_range_nodes_ is used when allocating steps for nodes of type kDimRangeNode. - /// This is a set of (source_step, dim_range_node_index), - /// where source_step is the step in which we computed of the input - /// of the dim-range node (this step will be for a node of type kComponentNode). - /// This just tells us whether we've already added a particular dim-range node - /// for this step, so we know whether we need to add it again. - std::unordered_set, PairHasher > dim_range_nodes_; };