Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/nnet3/nnet-computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,17 @@ struct NnetComputation {
// These are owned here.
std::vector<PrecomputedIndexesInfo> component_precomputed_indexes;

// used in kAddRows, kAddToRows, kCopyRows, kCopyToRows. contains row-indexes.
// Used in commands kAddRows, kAddToRows, kCopyRows, which
// contain indexes into this data-member.
// Each vector<int32> is a vector of row-indexes (with -1 usually treated as
// a special case meaning "don't do anything for this row" for add
// commands, or "use zero" for copy commands).
std::vector<std::vector<int32> > indexes;

// used kAddRowsMulti, kAddToRowsMulti, kCopyRowsMulti, kCopyToRowsMulti.
// contains pairs (sub-matrix index, row index)- or (-1,-1) meaning don't
// do anything for this row.
// Used in commands kAddRowsMulti, kAddToRowsMulti, kCopyRowsMulti and
// kCopyToRowsMulti. Contains pairs (sub-matrix index, row index)- or the
// special pair (-1,-1) meaning "don't do anything for this row" for add
// commands, or "use zero" for copy commands.
std::vector<std::vector<std::pair<int32,int32> > > indexes_multi;


Expand Down
322 changes: 322 additions & 0 deletions src/nnet3/nnet-optimize-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,328 @@ bool SnipRowOps(NnetComputation *computation) {



// This class implements the internals of the function SplitRowOps() which is
// declared in nnet-optimize-utils.h.
class RowOpsSplitter {
 public:
  RowOpsSplitter(NnetComputation *computation): computation_(computation) { }

  // Runs the optimization on the computation given to the constructor.
  // The commands are only examined if at least one multi-index could be
  // split.  Returns true if the computation was changed.
  bool Split() {
    if (!SplitIndexes())
      return false;
    return SplitCommands();
  }

 private:
  // Terminology: a "multi-index" is one element of
  // NnetComputation::indexes_multi, i.e. a list of pairs such as
  //   const std::vector<std::pair<int32,int32> > &multi_index =
  //       computation_->indexes_multi[1];

  // Describes one range of a multi-index over which the .first member of
  // every pair is the same.  For example, the multi-index
  //   ((10,2), (10,3), (10,4), (15,3), (15,5), (15,7))
  // is described by two SingleSplitInfo objects, one per range.
  struct SingleSplitInfo {
    // Position within the multi-index at which this range starts
    // (0 and 3 for the two ranges in the example).
    int32 offset;
    // Number of pairs in this range (3 and 3 in the example).
    int32 size;
    // The .first value shared by every pair in this range; it is a
    // submatrix index (10 and 15 in the example).
    int32 first_value;

    // The smallest .second value appearing in this range (2 and 3 in the
    // example).
    int32 min_second_value;

    // (largest .second value in this range) + 1 - min_second_value, i.e.
    // the number of rows of the other submatrix taking part in the
    // operation.
    int32 second_value_range;

    // Empty if the .second values of the range are consecutive.  Otherwise
    // it has 'size' elements, each in [0, second_value_range - 1], with
    //   min_second_value + second_value_offsets[i]
    // equal to the .second value of the i'th pair of the range.  For the
    // range ((15,3), (15,5), (15,7)) of the example this is (0, 2, 4).
    std::vector<int32> second_value_offsets;
  };

  // One of these exists per multi-index, i.e. per element of
  // NnetComputation::indexes_multi.
  struct MultiIndexSplitInfo {
    // Holds one or two ranges if the multi-index could be split up that
    // way; left empty if it could not.
    std::vector<SingleSplitInfo> splits;
  };

  // Fills in split_info_, which records how each element of
  // computation_->indexes_multi can be split.  Returns true if at least one
  // of those vectors was split successfully, false otherwise.
  bool SplitIndexes();

  // Rewrites the commands of the computation; returns true if anything
  // changed.
  bool SplitCommands();

  // Tries to optimize computation_->commands[command_index]; returns true
  // if it changed something.  When an additional command has to be inserted
  // into the computation, it is appended to new_commands_.
  bool SplitCommand(int32 command_index);

  // Summarizes the range [begin, end) of a multi-index in *info.  All
  // members are written except .offset, which the caller sets.  Returns
  // true on success and false otherwise; it fails if the range contains -1's
  // or its .first members are not all identical (the definition may apply
  // further limits).
  bool GetSplitInfo(std::vector<std::pair<int32, int32> >::const_iterator begin,
                    std::vector<std::pair<int32, int32> >::const_iterator end,
                    SingleSplitInfo *info);

  // The computation being modified.
  NnetComputation *computation_;
  // Split information, one entry per element of
  // computation_->indexes_multi.
  std::vector<MultiIndexSplitInfo> split_info_;
  // Extra commands still to be inserted into computation_, stored as
  // (command-index, command) pairs where the new command goes just before
  // 'command-index'.  This is the format that InsertCommands() accepts.
  std::vector<std::pair<int32, NnetComputation::Command> > new_commands_;

};


bool RowOpsSplitter::GetSplitInfo(
std::vector<std::pair<int32, int32> >::const_iterator begin,
std::vector<std::pair<int32, int32> >::const_iterator end,
SingleSplitInfo *info) {
// max_size_ratio must be > 1.0, and could in principle be a float. It is
// there to prevent us from making changes to the computation which would end
// up wastefully launching too many kernels that would do nothing.
const int32 max_size_ratio = 2;

int32 size = end - begin;
KALDI_ASSERT(size != 0);
int32 first = begin->first;
if (first < 0)
return false;
info->size = size;
info->first_value = first;
int32 initial_second_value = begin->second,
min_second_value = initial_second_value,
max_second_value = initial_second_value;
info->second_value_offsets.resize(size);
bool is_consecutive = true;
for (int32 i = 0; i < size; i++) {
int32 second = begin[i].second;
if (begin[i].first != first || second < 0) return false;
info->second_value_offsets[i] = second;
if (second != initial_second_value + i)
is_consecutive = false;
if (second < min_second_value) min_second_value = second;
if (second > max_second_value) max_second_value = second;
}
info->min_second_value = min_second_value;
info->second_value_range = max_second_value + 1 - min_second_value;
if (info->second_value_range > size * max_size_ratio)
return false;
if (is_consecutive) {
info->second_value_offsets.clear();
} else {
for (int32 i = 0; i < size; i++)
info->second_value_offsets[i] -= min_second_value;
}
return true;
}


bool RowOpsSplitter::SplitIndexes() {
bool ans = false;
int32 num_indexes_multi = computation_->indexes_multi.size();
split_info_.resize(num_indexes_multi);
for (int32 i = 0; i < num_indexes_multi; i++) {
const std::vector<std::pair<int32,int32> > &multi_index =
computation_->indexes_multi[i];
MultiIndexSplitInfo &split_info = split_info_[i];

int32 num_pairs = multi_index.size();
KALDI_ASSERT(num_pairs > 0);
// 'split_point' will be set to the first index j for which
// multi_index[j-1].first != multi_index[j].first, or -1
// if no such j exists.
int32 split_point = -1, initial_first = multi_index[0].first;
for (int32 j = 1; j < num_pairs; j++) {
if (multi_index[j].first != initial_first) {
split_point = j;
break;
}
}
if (split_point == -1) {
split_info.splits.resize(1);
split_info.splits[0].offset = 0;
if (!GetSplitInfo(multi_index.begin(), multi_index.end(),
&(split_info.splits[0]))) {
split_info.splits.clear();
} else {
ans = true;
}
} else {
split_info.splits.resize(2);
split_info.splits[0].offset = 0;
split_info.splits[1].offset = split_point;

std::vector<std::pair<int32,int32> >::const_iterator mid_iter =
multi_index.begin() + split_point;
if (!GetSplitInfo(multi_index.begin(), mid_iter,
&(split_info.splits[0])) ||
!GetSplitInfo(mid_iter, multi_index.end(),
&(split_info.splits[1]))) {
split_info.splits.clear();
} else {
ans = true;
}
}
}
return ans;
}

// Attempts to replace the command computation_->commands[c], if it is one of
// the four "multi" row operations, by one or two simpler commands
// (kMatrixAdd / kMatrixCopy when the rows involved are consecutive, otherwise
// kAddRows / kCopyRows).  The first replacement command overwrites command c
// in place; a second one, if needed, is queued in new_commands_ to be
// inserted just before command c + 1.  Returns true if the command was
// changed, false if it was left alone.
bool RowOpsSplitter::SplitCommand(int32 c) {
  NnetComputation::Command &command = computation_->commands[c];
  CommandType command_type = command.command_type;
  // For commands that are not of the following four types, return false: we
  // won't be changing these commands.
  switch (command_type) {
    case kAddRowsMulti: case kCopyRowsMulti:
    case kAddToRowsMulti: case kCopyToRowsMulti: break;
    default: return false;
  }
  int32 indexes_multi_index = command.arg2;
  KALDI_ASSERT(indexes_multi_index <
               static_cast<int32>(split_info_.size()));
  const MultiIndexSplitInfo &split_info = split_info_[indexes_multi_index];
  if (split_info.splits.empty())
    return false;  // these indexes couldn't be split: e.g. they contained more
                   // than two distinct .first elements, or there were other
                   // reasons.

  if (command_type == kCopyToRowsMulti) {
    // A kCopyToRowsMulti command whose rows are not consecutive can't be
    // converted: the inverted index vector (see the kAddToRowsMulti case
    // below) would contain -1's, and a copy command would set those rows of
    // the output to zero, whereas we'd want them to be unaffected.  Check for
    // this before touching the computation, so that we don't leave behind
    // unused sub-matrices or index vectors when we decline to make a change.
    for (size_t i = 0; i < split_info.splits.size(); i++)
      if (!split_info.splits[i].second_value_offsets.empty())
        return false;
  }

  // we'll be splitting the command into either one or two pieces.
  std::vector<NnetComputation::Command> split_commands(
      split_info.splits.size());
  for (size_t i = 0; i < split_info.splits.size(); i++) {
    const SingleSplitInfo &split = split_info.splits[i];
    NnetComputation::Command &command_out = split_commands[i];
    command_out.alpha = command.alpha;
    // arg1: the rows [offset, offset + size) of the original command's
    // sub-matrix.
    command_out.arg1 = computation_->NewSubMatrix(
        command.arg1, split.offset, split.size, 0, -1);
    // arg2: the rows of the other sub-matrix that this range refers to.
    command_out.arg2 = computation_->NewSubMatrix(
        split.first_value, split.min_second_value,
        split.second_value_range, 0, -1);

    if (split.second_value_offsets.empty()) {
      // The .second elements are consecutive, so a plain matrix add or copy
      // does the job.  For the "To" variants the roles of the two
      // sub-matrices are exchanged.
      switch (command_type) {
        case kAddRowsMulti:
          command_out.command_type = kMatrixAdd;
          break;
        case kCopyRowsMulti:
          command_out.command_type = kMatrixCopy;
          break;
        case kAddToRowsMulti:
          command_out.command_type = kMatrixAdd;
          std::swap(command_out.arg1, command_out.arg2);
          break;
        case kCopyToRowsMulti:
          command_out.command_type = kMatrixCopy;
          std::swap(command_out.arg1, command_out.arg2);
          break;
        default:  // will never be reached.
          break;
      }
    } else {
      // Indexes are not consecutive: it needs to be a kAddRows or kCopyRows
      // command.  arg3 is the position in computation_->indexes at which the
      // index vector is about to be appended.
      command_out.arg3 = computation_->indexes.size();
      switch (command_type) {
        case kAddRowsMulti: case kCopyRowsMulti: {
          command_out.command_type = (command_type == kAddRowsMulti ?
                                      kAddRows : kCopyRows);
          computation_->indexes.push_back(split.second_value_offsets);
          break;
        }
        case kAddToRowsMulti: {
          command_out.command_type = kAddRows;
          std::swap(command_out.arg1, command_out.arg2);
          // invert the indexes: for each row of the destination, record
          // which row of the source is added to it (-1 meaning none).
          std::vector<int32> indexes(split.second_value_range, -1);
          for (int32 i = 0; i < split.size; i++) {
            int32 dest_row = split.second_value_offsets[i];
            // the following assert should always succeed because the
            // AddToRowsMulti and CopyToRowsMulti should never have
            // duplicate destinations in their indexes, so each entry must
            // still have its initial value of -1 when we come to set it.
            KALDI_ASSERT(indexes[dest_row] < 0);
            indexes[dest_row] = i;
          }
          computation_->indexes.push_back(indexes);
          break;
        }
        default:
          // kCopyToRowsMulti with non-consecutive rows was rejected above,
          // so nothing should get here.
          KALDI_ERR << "Code error: un-handled case.";
      }
    }
  }
  command = split_commands[0];
  // note: for now, split_commands.size() will be 1 or 2.
  for (size_t i = 1; i < split_commands.size(); i++) {
    new_commands_.resize(new_commands_.size() + 1);
    // we'll want to insert this command right after command c,
    // which is the same as just before command c + 1.
    new_commands_.back().first = c + 1;
    new_commands_.back().second = split_commands[i];
  }
  return true;  // We made a change.
}

bool RowOpsSplitter::SplitCommands() {
bool ans = false;
int32 num_commands = computation_->commands.size();
for (int32 c = 0; c < num_commands; c++)
if (SplitCommand(c))
ans = true;
if (!new_commands_.empty())
InsertCommands(&new_commands_, computation_);
return ans;
}

// Thin wrapper around RowOpsSplitter: applies the optimization to
// *computation and returns true if the computation was changed.
bool SplitRowOps(NnetComputation *computation) {
  return RowOpsSplitter(computation).Split();
}


/*
Expand Down
Loading