Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 49 additions & 33 deletions src/nnet3/nnet-computation-graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ int32 ComputationStepsComputer::AddStep(const std::vector<Cindex> &cindexes,
*out_iter = cindex_id;
if (added) {
KALDI_ASSERT(cindex_id == static_cast<int32>(locations_->size()));
locations_->resize(cindex_id + 1, std::pair<int32, int32>(-1, -1));
locations_->resize(cindex_id + 1);
locations_->back().first = step_index;
locations_->back().second = row_index;
locations = &((*locations_)[0]); // in case it was reallocated
Expand Down Expand Up @@ -1867,52 +1867,68 @@ void ComputationStepsComputer::ProcessDimRangeSubPhase(
ConvertToCindexIds(input_cindexes, &input_cindex_ids);
std::vector<std::pair<int32, int32> > locations;
ConvertToLocations(input_cindex_ids, &locations);

// get a list of the source step indexes (corresponding to computations for the
// source component-node)
std::unordered_set<int32> source_step_indexes;
std::sort(locations.begin(), locations.end());
KALDI_ASSERT(!locations.empty());
std::vector<std::pair<int32, int32> >::const_iterator
locations_iter = locations.begin(),
locations_end = locations.end();

// 'cur_source_step_index' is just an optimization to prevent unnecessary
// unordered_set inserts.
int32 cur_source_step_index = -1;
for (; locations_iter != locations_end; ++locations_iter) {
int32 source_step_index = locations_iter->first;
if (source_step_index != cur_source_step_index) {
cur_source_step_index = source_step_index;
source_step_indexes.insert(cur_source_step_index);
// Each unique .first number in locations (i.e. each source step, and they
// will all correspond to component-output or input steps) will generate one
// 'step' of type kDimRange. Because dim-range nodes must be contiguous
// ranges of a source step (since they are represented as sub-matrices), for
// each source step we work out the first and last row-index (i.e. first and
// last .second member of locations) and use that to reconstruct the range.

// each element of 'steps' will be (source_step, (begin_row, end_row)) so that
// the source of the dim-range node is indexes begin_row ... end_row-1 in that
// source step.
std::vector<std::pair<int32, std::pair<int32, int32> > > steps;

int32 cur_source_step = locations_iter->first,
cur_row_begin = locations_iter->second,
cur_row_end = cur_row_begin + 1;
while (1) {
++locations_iter;
if (locations_iter == locations_end ||
locations_iter->first != cur_source_step) {
// we reached the end of a run of the same step.
std::pair<int32, std::pair<int32, int32> > this_step;
this_step.first = cur_source_step;
this_step.second.first = cur_row_begin;
this_step.second.second = cur_row_end;
steps.push_back(this_step);
if (locations_iter != locations_end) {
cur_source_step = locations_iter->first;
cur_row_begin = locations_iter->second;
cur_row_end = cur_row_begin + 1;
} else {
break;
}
} else {
cur_row_end = locations_iter->second + 1;
}
}

std::unordered_set<int32>::const_iterator
source_step_iter = source_step_indexes.begin(),
source_step_end = source_step_indexes.end();
// iterating over the indexes of the source steps.
for (; source_step_iter != source_step_end; ++source_step_iter) {
int32 source_step_index = *source_step_iter;
std::pair<int32, int32> p(source_step_index, dim_range_node);
if (dim_range_nodes_.count(p) > 0) {
// We don't need to do anything; a dim-range node already exists for this
// step and this node index.
continue;
}
dim_range_nodes_.insert(p);
const std::vector<int32> &source_step = (*steps_)[source_step_index];
// 'cindexes' will be the cindexes of the new step that we're going to add.
for (size_t i = 0; i < steps.size(); i++) {
// iterating over different source steps, although normally
// there will be just one.
int32 source_step = steps[i].first,
row_begin = steps[i].second.first,
row_end = steps[i].second.second;
// 'source' is just the elements of the source step that we're consuming.
std::vector<int32> source((*steps_)[source_step].begin() + row_begin,
(*steps_)[source_step].begin() + row_end);
std::vector<Cindex> cindexes;
ConvertToCindexes(source_step, &cindexes);
ConvertToCindexes(source, &cindexes);
std::vector<Cindex>::iterator iter = cindexes.begin(),
end = cindexes.end();
for (; iter != end; ++iter)
iter->first = dim_range_node;
bool add_if_absent = true;
// this add_if_absent says, even if cindexes were not in the graph,
// add them. This is possible; the step will contain all cindexes for the
// input step, even if they won't be needed. (This is costless; it's just
// setting up a sub-matrix).
// add them. This is possible in principle; it's to satisfy the
// requirement that DimRangeNodes be implemented as contiguous ranges
// of rows of component nodes or input nodes.
AddStep(cindexes, add_if_absent);
}
}
Expand Down
11 changes: 1 addition & 10 deletions src/nnet3/nnet-computation-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class ComputationStepsComputer {
/// (step-index, index-into-step), so that for any cindex_id c,
/// (*steps)[locations[c].first][locations[c].second] == c.
/// It's possible in principle if there are non-simple
/// Components, that for nodes corresponding to component-input
/// Components, that for node corresponding to component-input
/// descriptors, a cindex might be present in more than one step,
/// so it doesn't follow that if (*steps)[i][j] == c, then
/// locations[c] == (i,j).
Expand Down Expand Up @@ -547,15 +547,6 @@ class ComputationStepsComputer {
/// (*steps_)[i][j] == c. This is also an output (we get the pointer in
/// the constructor).
std::vector<std::pair<int32, int32> > *locations_;


/// dim_range_nodes_ is used when allocating steps for nodes of type kDimRangeNode.
/// This is a set of (source_step, dim_range_node_index),
/// where source_step is the step in which we computed of the input
/// of the dim-range node (this step will be for a node of type kComponentNode).
/// This just tells us whether we've already added a particular dim-range node
/// for this step, so we know whether we need to add it again.
std::unordered_set<std::pair<int32, int32>, PairHasher<int32> > dim_range_nodes_;
};


Expand Down