fix first-order AD on tuple arguments
altanh committed Nov 2, 2020
1 parent 3222cad commit d3da8f3
Showing 2 changed files with 52 additions and 1 deletion.
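Summary of the change: the first-order AD pass in gradient.cc accumulates a gradient for every operator argument, and previously it did so with a single Add, which only type-checks when the argument is a tensor. The new UpdateGrad helper instead recurses over the argument's checked type and accumulates per tensor field, so tuple-typed arguments (including nested tuples) are handled. A Relay-level illustration of the difference, not part of the commit:

```python
# Illustration only (not from the commit): accumulating two tuple-typed
# values has to be done per field, because relay.add is a tensor op and
# does not type-check on tuple-typed operands.
import tvm
from tvm import relay

t = relay.TensorType((2, 3), "float32")
a = relay.var("a", type_annotation=relay.TupleType([t, t]))
b = relay.var("b", type_annotation=relay.TupleType([t, t]))

# Per-field accumulation, mirroring what UpdateGrad emits for a TupleType:
summed = relay.Tuple(
    [relay.add(relay.TupleGetItem(a, i), relay.TupleGetItem(b, i)) for i in range(2)]
)
func = relay.Function([a, b], summed)
```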
20 changes: 19 additions & 1 deletion src/relay/transforms/gradient.cc
@@ -181,6 +181,22 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
return ret;
}

+  Expr UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
+    if (t.as<TensorTypeNode>()) {
+      return ll->Push(Add(arg, grad));
+    } else if (auto* tt = t.as<TupleTypeNode>()) {
+      Array<Expr> updates;
+      for (size_t i = 0; i < tt->fields.size(); ++i) {
+        updates.push_back(this->UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)),
+                                           ll->Push(GetField(grad, i)), ll));
+      }
+      return ll->Push(Tuple(updates));
+    } else {
+      LOG(FATAL) << "unsupported arg type of operator: " << t;
+      throw;
+    }
+  }
+
ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
ICHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
@@ -198,8 +214,10 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
ICHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
+        auto ad_arg = args[i]->get<ADTensor>();
+        auto ad_arg_type = ad_arg.forward->checked_type();
        args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
+            this->UpdateGrad(ad_arg_type, ad_arg.reverse, rev[i], ll);
}
});
return ret;
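The UpdateGrad helper added above drives the fix: at a TensorTypeNode leaf it emits a plain Add, at a TupleTypeNode it recurses into each field via GetField and rebuilds the tuple, and any other type is rejected. A rough Python analogue of that recursion, illustration only, using nested numpy values in place of Relay expressions and the LetList:

```python
# Rough Python analogue of UpdateGrad (illustration only): accumulate an
# incoming gradient into the existing one, recursing through tuples and
# adding at tensor leaves.
import numpy as np

def update_grad(arg_grad, new_grad):
    if isinstance(arg_grad, np.ndarray):
        # Tensor leaf: elementwise add, like ll->Push(Add(arg, grad)).
        return arg_grad + new_grad
    if isinstance(arg_grad, tuple):
        # Tuple: recurse per field, like the TupleTypeNode branch.
        return tuple(update_grad(a, g) for a, g in zip(arg_grad, new_grad))
    raise TypeError("unsupported arg type of operator: %r" % type(arg_grad))

# A tuple-of-tensors argument accumulates field by field:
acc = (np.zeros((2, 3)), np.zeros((2, 3)))
inc = (np.ones((2, 3)), np.ones((2, 3)))
acc = update_grad(acc, inc)  # each field is now all ones
```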
33 changes: 33 additions & 0 deletions tests/python/relay/test_pass_gradient.py
@@ -255,6 +255,29 @@ def _test_tuple(mode):
tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))


+def _test_tuple_argument(mode):
+    shape = (2, 3)
+    dtype = "float32"
+    tensor_type = relay.TensorType(shape, dtype)
+    fields = 3
+    tuple_type = relay.TupleType([tensor_type] * fields)
+    tup = relay.var("tup", type_annotation=tuple_type)
+    body = relay.TupleGetItem(tup, 0)
+    for i in range(1, fields):
+        body = relay.add(body, relay.TupleGetItem(tup, i))
+    func = relay.Function([tup], body)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func, mode=mode))
+    xs = [rand(dtype, *shape) for _ in range(fields)]
+    xs_np = np.array([x.asnumpy() for x in xs])
+    expected_forward = np.sum(xs_np, axis=0)
+    ex = create_executor()
+    forward, grad = ex.evaluate(back_func)(tuple(xs))
+    tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    for field in grad[0]:
+        tvm.testing.assert_allclose(field.asnumpy(), np.ones_like(field.asnumpy()))
+
+
def test_tuple():
_test_tuple("higher_order")

@@ -263,6 +286,16 @@ def test_tuple_first_order():
_test_tuple("first_order")


+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_tuple_argument():
+    # fails until we add support for top-level tuple arguments in higher-order AD
+    _test_tuple_argument("higher_order")
+
+
+def test_tuple_argument_first_order():
+    _test_tuple_argument("first_order")
+
+
def test_pow():
mod = tvm.IRModule()
p = Prelude(mod)
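For reference, a standalone sketch of what the new first-order test exercises, assuming a TVM build that includes this commit (names mirror _test_tuple_argument above):

```python
# Sketch only: first-order gradient of a function whose single parameter
# is a 3-tuple of tensors; the gradient w.r.t. that parameter is itself a
# 3-tuple of all-ones tensors, since the body just sums the fields.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import run_infer_type
from tvm.relay.transform import gradient

t = relay.TensorType((2, 3), "float32")
tup = relay.var("tup", type_annotation=relay.TupleType([t, t, t]))
body = relay.add(relay.add(relay.TupleGetItem(tup, 0), relay.TupleGetItem(tup, 1)),
                 relay.TupleGetItem(tup, 2))
func = run_infer_type(relay.Function([tup], body))
back_func = run_infer_type(gradient(func, mode="first_order"))

xs = tuple(tvm.nd.array(np.random.rand(2, 3).astype("float32")) for _ in range(3))
forward, grads = relay.create_executor().evaluate(back_func)(xs)
# forward matches xs[0] + xs[1] + xs[2]; grads[0] is the tuple of gradients.
for field in grads[0]:
    np.testing.assert_allclose(field.asnumpy(), np.ones((2, 3)))
```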
