Skip to content

Commit

Permalink
[Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)
Browse files Browse the repository at this point in the history
  • Loading branch information
csullivan authored May 4, 2021
1 parent 284faf2 commit 396a09e
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 31 deletions.
175 changes: 144 additions & 31 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <tvm/runtime/logging.h>

#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "../op/tensor/transform.h"
Expand Down Expand Up @@ -117,36 +119,20 @@ class SimplifyTranspose : public DFPatternRewrite {

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
// Helper function to get the axes from call node attribute
auto get_axes_from_call = [](const Call trans_call, int ndim) {
std::vector<int> attr_axes;
if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
if (attr->axes.defined()) {
for (int i = 0; i < ndim; ++i) {
int64_t axis = attr->axes[i];
axis += (axis < 0) ? ndim : 0;
attr_axes.push_back(axis);
}
} else {
// Empty axes means reverse
for (int i = ndim - 1; i >= 0; --i) {
attr_axes.push_back(i);
}
}
} else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
Layout src_layout(attr->src_layout);
Layout dst_layout(attr->dst_layout);
for (int i = 0; i < ndim; ++i) {
attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
auto x = node_map[x_][0];

Call trans_call = Downcast<Call>(post);

// Try to fuse any rank changing layout transformations
if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
// Prune any trivial layout transformation
if (attr->src_layout == attr->dst_layout) {
return x;
}
} else {
CHECK(false) << "Expected transpose or layout_transform, but got "
<< Downcast<Op>(trans_call->op)->name;
}
return std::move(attr_axes);
};

auto x = node_map[x_][0];
return layout_trans.value();
}

// Initialize axes
int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
Expand All @@ -157,10 +143,9 @@ class SimplifyTranspose : public DFPatternRewrite {

// Collect axes changes from the matched pattern, including two consecutive transposes.
std::vector<std::vector<int>> interm_axes;
Call trans_call = Downcast<Call>(post);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
trans_call = Downcast<Call>(trans_call->args[0]);
interm_axes.push_back(get_axes_from_call(trans_call, ndim));
interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));

// Calculate the final axes in reverse order (from root to output)
auto it = interm_axes.rbegin();
Expand Down Expand Up @@ -190,6 +175,134 @@ class SimplifyTranspose : public DFPatternRewrite {
return x;
}

/*!
 * \brief Permute the named axes of a layout string by the given axis order.
 *
 * For example, layout "NCHW" with axes_order {0, 2, 3, 1} yields "NHWC".
 *
 * \param layout The layout string whose characters (named axes) are permuted.
 * \param axes_order The permutation; axes_order[i] selects the source axis
 *        placed at position i of the result. Must have one entry per axis.
 * \return The permuted layout string.
 */
String PermuteLayout(const String& layout, const std::vector<int>& axes_order) const {
  std::string new_layout{};
  std::string old_layout{layout};
  ICHECK_EQ(axes_order.size(), layout.size())
      << "Number of axes must match the number of named axes in the layout to permute: length("
      << old_layout << ") != " << axes_order.size();
  // Avoid reallocation while building the permuted string.
  new_layout.reserve(old_layout.size());
  std::stringstream order;
  for (auto axis : axes_order) {
    new_layout += old_layout[axis];
    order << axis << ", ";
  }
  DLOG(INFO) << "Using transpose axes order {" << order.str()
             << "} to permute layout: " << old_layout << " to " << new_layout;
  return new_layout;
}

// Describes a pair of adjacent transform calls in which at least one is a
// layout_transform that changes tensor rank (e.g. "NCHW" -> "NCHW4c").
struct RankChangingLayoutDescriptor {
  // Source layout of the rank changing layout_transform
  Layout src_layout;
  // Destination layout of the rank changing layout_transform
  Layout dst_layout;
  // Either a rank changing layout transform or a transpose
  Call other_transform;
};

/*!
 * \brief Inspect a pair of consecutive transform calls and detect whether
 * either is a rank changing layout_transform.
 *
 * \param call The outer (second) of the two matched transform calls; its
 *        first argument is the inner (first) call.
 * \return A descriptor holding the rank-changing layouts and the remaining
 *         transform call, or nullptr if neither call changes rank.
 */
std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
  std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
  // Case 1: the outer call is a layout_transform whose layouts differ in length,
  // i.e. it changes rank. The inner call becomes the "other" transform.
  if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
    if (attr->src_layout.length() != attr->dst_layout.length()) {
      desc = std::make_unique<RankChangingLayoutDescriptor>();
      desc->src_layout = Layout(attr->src_layout);
      desc->dst_layout = Layout(attr->dst_layout);
      desc->other_transform = Downcast<Call>(call->args[0]);
    }
  }
  // Case 2: the inner call is also (or instead) a rank changing layout_transform.
  if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
    if (attr->src_layout.length() != attr->dst_layout.length()) {
      if (!desc) {
        // Only the inner call changes rank; the outer call is the "other" transform.
        desc = std::make_unique<RankChangingLayoutDescriptor>();
        desc->src_layout = Layout(attr->src_layout);
        desc->dst_layout = Layout(attr->dst_layout);
        desc->other_transform = call;
      } else {
        // Both calls change rank: they can only be fused if the outer call's
        // source layout equals the inner call's destination layout.
        ICHECK(desc->src_layout->name == attr->dst_layout)
            << "Back-to-back layout transforms must have the same intermediate layout: "
            << desc->src_layout->name << " != " << attr->dst_layout;
        // The fused transform runs from the inner call's source layout.
        desc->src_layout = Layout(attr->src_layout);
      }
    }
  }
  return desc;
}

/*
 * \brief Fuse call and its argument into a single layout_transform operator
 * when either call or its argument is a rank changing layout_transform, e.g.,
 *
 * Simplify
 *
 *   [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
 *
 * to,
 *
 *   [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
 *
 * \param data The input expression to the matched pattern
 * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
 * \return The fused layout_transform call, or no value when neither op changes rank.
 */
Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
  // Check to see if either the first or second call in matched pattern
  // is a rank changing layout transform. If so, return a descriptor containing
  // the layouts and any additional transpose or layout transform op.
  auto desc = GetRankChangeDescriptor(call);
  if (desc == nullptr) {
    // No rank changing layout transform
    return Optional<Call>{nullptr};
  }

  Optional<Expr> output_layout_trans;
  // Fuse a rank increasing layout transform and a preceding transpose
  if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
    auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
    // Calculate the reverse axis order and apply to the source layout
    std::vector<int> inverse(axes.size());
    for (size_t i = 0; i < axes.size(); i++) {
      inverse[axes[i]] = i;
    }
    String new_layout = PermuteLayout(desc->src_layout->name, inverse);
    output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
    // Fuse a rank decreasing layout transform followed by a transpose
  } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
    auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
    // The transpose comes after the rank reduction, so permute the destination layout
    String new_layout = PermuteLayout(desc->dst_layout->name, axes);
    output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
    // Fuse two back-to-back layout transformations which change rank
  } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
    output_layout_trans =
        MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
  }
  return Downcast<Call>(output_layout_trans);
}

/*!
 * \brief Extract the axis permutation performed by a transpose or
 * layout_transform call as a vector of non-negative axis indices.
 *
 * \param call A call to transpose (axes attribute, possibly negative or
 *        undefined meaning "reverse") or layout_transform (permutation derived
 *        from the source/destination layout strings). Any other op is an error.
 * \param ndim Number of dimensions to permute.
 * \return The normalized axis order, one entry per dimension.
 */
std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
  std::vector<int> attr_axes;
  attr_axes.reserve(ndim);
  if (auto attr = call->attrs.as<TransposeAttrs>()) {
    if (attr->axes.defined()) {
      for (int i = 0; i < ndim; ++i) {
        // Normalize negative axes into [0, ndim)
        int64_t axis = attr->axes[i];
        axis += (axis < 0) ? ndim : 0;
        attr_axes.push_back(axis);
      }
    } else {
      // Empty axes means reverse
      for (int i = ndim - 1; i >= 0; --i) {
        attr_axes.push_back(i);
      }
    }
  } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
    // The permutation maps each destination axis back to its source position.
    Layout src_layout(attr->src_layout);
    Layout dst_layout(attr->dst_layout);
    for (int i = 0; i < ndim; ++i) {
      attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
    }
  } else {
    CHECK(false) << "Expected transpose or layout_transform, but got "
                 << Downcast<Op>(call->op)->name;
  }
  // Plain return enables NRVO; std::move here would pessimize.
  return attr_axes;
}

private:
/*! \brief Pattern input */
DFPattern x_;
Expand Down
166 changes: 166 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,176 @@ def expected3():
y = relay.transpose(y, axes=[0, 2, 3, 1])
return relay.Function([x], y)

# Test a series of transpose and rank changing layout_transform
def before4():
    """
    Simplify transpose->layout_transform and its inverse.
    Input:
        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
    Simplified:
        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
    """
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
    t = relay.transpose(inp, axes=[0, 3, 1, 2])
    t = relay.layout_transform(t, "NCHW", "NCHW4c")
    t = relay.nn.relu(t)
    t = relay.layout_transform(t, "NCHW4c", "NCHW")
    t = relay.transpose(t, axes=[0, 2, 3, 1])
    return relay.Function([inp], t)

def expected4():
    # The transpose/layout_transform pairs collapse into single NHWC<->NCHW4c transforms.
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
    t = relay.layout_transform(inp, "NHWC", "NCHW4c")
    t = relay.nn.relu(t)
    t = relay.layout_transform(t, "NCHW4c", "NHWC")
    return relay.Function([inp], t)

def before5():
    """
    Simplify layout_transform->layout_transform and its inverse.
    Input:
        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
    Simplified:
        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
    """
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
    t = relay.layout_transform(inp, "NHWC", "NCHW")
    t = relay.layout_transform(t, "NCHW", "NCHW4c")
    t = relay.nn.relu(t)
    t = relay.layout_transform(t, "NCHW4c", "NCHW")
    t = relay.layout_transform(t, "NCHW", "NHWC")
    return relay.Function([inp], t)

def expected5():
    # Back-to-back layout transforms fuse into single NHWC<->NCHW4c transforms.
    inp = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
    t = relay.layout_transform(inp, "NHWC", "NCHW4c")
    t = relay.nn.relu(t)
    t = relay.layout_transform(t, "NCHW4c", "NHWC")
    return relay.Function([inp], t)

def before6():
    """
    Remove trivial layout_transform->layout_transform.
    Input:
        NCHW -> NHWC -> NCHW -> op
    Simplified:
        NCHW -> op
    """
    # NOTE: docstring previously said "NHWC -> op"; the simplified graph keeps
    # the original NCHW input feeding the op directly.
    x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    y = relay.layout_transform(x, "NCHW", "NHWC")
    y = relay.layout_transform(y, "NHWC", "NCHW")
    y = relay.nn.relu(y)
    return relay.Function([x], y)

def expected6():
    # Inverse layout transforms cancel; relu consumes the input directly.
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    return relay.Function([inp], relay.nn.relu(inp))

def before7():
    """
    Remove trivial layout_transform->layout_transform.
    Input:
        NCHW4c -> NCHW8c -> NCHW4c -> op
    Simplified:
        NCHW4c -> op
    """
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    t = relay.layout_transform(inp, "NCHW4c", "NCHW8c")
    t = relay.layout_transform(t, "NCHW8c", "NCHW4c")
    return relay.Function([inp], relay.nn.relu(t))

def expected7():
    # Round-trip through NCHW8c cancels out entirely.
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    return relay.Function([inp], relay.nn.relu(inp))

def before8():
    """
    Simplify layout_transform->layout_transform with rank contraction and expansion
    Input:
        NCHW4c -> NCHW -> NCHW8c -> op
    Simplified:
        NCHW4c -> NCHW8c -> op
    """
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    t = relay.layout_transform(inp, "NCHW4c", "NCHW")
    t = relay.layout_transform(t, "NCHW", "NCHW8c")
    return relay.Function([inp], relay.nn.relu(t))

def expected8():
    # Contraction and expansion fuse into one direct NCHW4c -> NCHW8c transform.
    inp = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
    t = relay.layout_transform(inp, "NCHW4c", "NCHW8c")
    return relay.Function([inp], relay.nn.relu(t))

def before9():
    """
    Remove trivial layout_transform->layout_transform.
    Input:
        NCHW -> NCHW4c -> NCHW -> op
    Simplified:
        NCHW -> op
    """
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    t = relay.layout_transform(inp, "NCHW", "NCHW4c")
    t = relay.layout_transform(t, "NCHW4c", "NCHW")
    return relay.Function([inp], relay.nn.relu(t))

def expected9():
    # Rank-changing round trip cancels; relu consumes the input directly.
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    return relay.Function([inp], relay.nn.relu(inp))

def before10():
    """
    Simplify layout_transform->layout_transform without rank change to transpose.
    Input:
        NCHW -> NHWC -> CHWN -> op
    Simplified:
        NCHW -> CHWN -> op
    """
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    t = relay.layout_transform(inp, "NCHW", "NHWC")
    t = relay.layout_transform(t, "NHWC", "CHWN")
    return relay.Function([inp], relay.nn.relu(t))

def expected10():
    # Same-rank layout transforms reduce to a single transpose (NCHW -> CHWN).
    inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
    t = relay.transpose(inp, axes=[1, 2, 3, 0])
    return relay.Function([inp], relay.nn.relu(t))

for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
[before3(), expected3()],
[before4(), expected4()],
[before5(), expected5()],
[before6(), expected6()],
[before7(), expected7()],
[before8(), expected8()],
[before9(), expected9()],
[before10(), expected10()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())
Expand Down

0 comments on commit 396a09e

Please sign in to comment.