
Tvm broadcast backward #15938

Merged: 3 commits into apache:master on Aug 22, 2019

Conversation

@hzfan (Contributor) commented Aug 18, 2019

Description

Use TVM to implement the backward pass of vadd with broadcasting.

Changes

  • Add vadd backward.

Comments

  • Since vadd is not a NumPy op, 0-dim and 0-size inputs are not supported.
  • For now, I implemented infrastructure-level concerns at the op level, such as:
    • dispatch over the different req modes
    • dispatch of the backward kernel

I think this code can be reused in the future. It would be great to have a consistent interface for TVM ops that hides dispatch details such as req and backward.

Thanks to @yzhliu and @junrushao1994 for the brilliant "compressed bit string" idea.
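
As a rough illustration of what "vadd backward with broadcast" means (a hypothetical NumPy-level sketch, not code from this PR): the gradient of a broadcast add is the output gradient summed over the axes that were broadcast, reshaped back to each input's shape.

import numpy as np

# Hypothetical sketch: sum ograd over the broadcast axes of each input.
def vadd_backward_sketch(ograd, a_shape, b_shape):
    def reduce_to(grad, shape):
        ndiff = grad.ndim - len(shape)
        grad = grad.sum(axis=tuple(range(ndiff)))  # axes prepended by broadcasting
        axes = tuple(i for i, s in enumerate(shape)
                     if s == 1 and grad.shape[i] != 1)
        return grad.sum(axis=axes, keepdims=True).reshape(shape)
    return reduce_to(ograd, a_shape), reduce_to(ograd, b_shape)

# e.g. a: (3, 1, 4), b: (2, 4)  ->  out and ograd: (3, 2, 4)
ograd = np.ones((3, 2, 4))
da, db = vadd_backward_sketch(ograd, (3, 1, 4), (2, 4))
print(da.shape, db.shape)  # (3, 1, 4) (2, 4)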

@junrushao (Member):

CC @tqchen if you have bandwidth

@haojin2 self-assigned this on Aug 19, 2019
@haojin2 added the Numpy label on Aug 19, 2019
@hzfan force-pushed the bc_pr branch 3 times, most recently from 9fc3389 to ad97a7e on August 19, 2019 09:22

return b, c


def reduce_axes(X, axes, reducer):

Member:

Can we add some comments to elaborate the idea, e.g. the meaning of axes? Also, can we move it somewhere else so that other operators can reuse it?

Contributor Author (hzfan):

Yes. Added in ufunc.py
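
For readers following along, here is a minimal sketch of what a reduce_axes helper of this shape could look like, assuming the pre-0.7 TVM API already used in this file (tvm.compute, tvm.reduce_axis) and assuming axes is a per-dimension flag list where axes[i] marks a broadcast dimension to be summed out. This is an illustration, not the actual ufunc.py code:

import tvm

def reduce_axes(X, axes, reducer):
    # axes: list of 0/1 flags, one per dimension of X; 1 means "sum this
    # dimension out" (it was broadcast in the forward pass). Assumes at
    # least one flag is set.
    def get_index(idx, ridx):
        # Use a reduction index for reduced dims, a spatial index otherwise.
        j, k, ret = 0, 0, []
        for r in axes:
            if r:
                ret.append(ridx[k])
                k += 1
            else:
                ret.append(idx[j])
                j += 1
        return tuple(ret)

    oshape = [X.shape[i] for i, r in enumerate(axes) if not r]
    ridx = [tvm.reduce_axis((0, X.shape[i])) for i, r in enumerate(axes) if r]
    return tvm.compute(oshape,
                       lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx),
                       name='reduced')

# reducer would typically be tvm.sum for the vadd backward pass.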

@@ -48,3 +50,71 @@ def vadd_gpu(dtype, ndim):
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B, C]


def assign_by_req(a, req):

Member:

Move it to something like common.py?

Contributor Author (hzfan):

Shall we use the existing contrib/tvmop/utils.py or create a contrib/tvmop/basic/common.py?

Member:

utils.py is fine

Contributor Author (hzfan):

Moved.
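
For context, here is a minimal sketch of what such an assign_by_req helper could look like (illustration only; the version moved into contrib/tvmop/utils.py may differ), again assuming the pre-0.7 TVM API used in this PR. Given a computed tensor a and a write request req, it returns the previous output buffer b and the new output c: kAddTo accumulates into b, kWriteTo simply copies a.

import tvm

def assign_by_req(a, req):
    b = tvm.placeholder(a.shape, name="assign_input", dtype=a.dtype)
    if req == "kAddTo":
        # Accumulate the new result into the existing output buffer.
        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx], name="assign_sum")
    else:  # kWriteTo / kWriteInplace
        c = tvm.compute(a.shape, lambda *idx: a[idx], name="assign_copy")
    return b, c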

funcname += "req_";
MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
if (req_type == kWriteTo) {
funcname += "kWriteTo";

Member:

Please fix the alignment here.

Contributor Author (hzfan):

Aligned.
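
The C++ snippet above selects a pre-generated kernel by composing its name from the runtime req value. As a purely hypothetical illustration of that "dispatch by kernel name" idea (the actual naming scheme and generator in this PR may differ), kernels could be compiled ahead of time for every (dtype, ndim, req) combination and looked up by the same composed string at runtime:

# Hypothetical sketch of name-based kernel dispatch (not the PR's exact scheme).
def kernel_name(op, dtype, ndim, req):
    return "{}_{}_{}d_req_{}".format(op, dtype, ndim, req)

# Generation time: compile one kernel per combination and register it by name.
registry = {}
for dtype in ("float32", "float64"):
    for ndim in range(1, 6):
        for req in ("kWriteTo", "kAddTo"):
            registry[kernel_name("backward_vadd", dtype, ndim, req)] = None  # compiled func

# Runtime: the C++ side rebuilds the matching name, e.g. one ending in "req_kWriteTo".
print(kernel_name("backward_vadd", "float32", 5, "kWriteTo"))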

// dispatch by backward
std::vector<int> ov, iv;
const TBlob& ograd = inputs[0], igrad = outputs[k];
bool flag = ograd.size(0) != igrad.size(0);

Member:

It would be better to use an int and explicitly assign the value.

Contributor Author (hzfan):

What about expanding it into an if-else?

Member:

Sounds good.

}
TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
TBlob igrad_tvm(igrad.reshape(ishape).dltensor());

Member:

Please add some comments to elaborate the idea.

Contributor Author (hzfan):

Added in ufunc.py
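
The idea being documented is the dimension collapse that produces ov and iv in the C++ snippet above: consecutive output dimensions with the same status (reduced versus kept in igrad) are merged, so the generated kernel only has to handle shapes whose reduce/keep pattern alternates. A rough Python rendering of that idea (hypothetical; the actual logic lives in the C++ code around this diff):

def collapse_dims(oshape, ishape):
    # oshape: shape of ograd; ishape: shape of igrad, assumed to have the
    # same ndim, with broadcast dimensions of igrad having size 1.
    ov, iv = [oshape[0]], [ishape[0]]
    reducing = oshape[0] != ishape[0]
    for o, i in zip(oshape[1:], ishape[1:]):
        if (o != i) == reducing:
            # Same status as the previous dimension: merge into it.
            ov[-1] *= o
            iv[-1] *= i
        else:
            ov.append(o)
            iv.append(i)
            reducing = o != i
    return ov, iv

print(collapse_dims((2, 3, 4, 5, 6), (2, 1, 1, 5, 6)))
# -> ([2, 12, 30], [2, 1, 30]): keep, reduce, keep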

std::vector<int> ov, iv;
const TBlob& ograd = inputs[0], igrad = outputs[k];
bool flag = ograd.size(0) != igrad.size(0);
for (int i = 0; i < ndim; ++i) {

Contributor:

If my understanding is correct, there seems to be an assumption that ograd.ndim = igrad.ndim, which is not necessarily true. I think you need to prepend axes to igrad if igrad.ndim < ograd.ndim and then apply the logic here.

Contributor Author (hzfan), Aug 21, 2019:

Yes, igrad.ndim = ograd.ndim is assumed.

@yzhliu suggests padding the inputs to 5 dims, the largest dimensionality supported by this op. The padding would 1) reduce the number of generated kernels (by a factor of 5) and 2) handle the igrad.ndim < ograd.ndim issue, but there may be some loss in performance.

I think prepending axes to igrad to match ograd.ndim requires more kernels but performs better. It is a tradeoff.

Contributor:

Please correct me if my understanding is wrong, but don't you still need kernels generated for ndim < 5, since you collapse consecutive dimensions where the reduction is performed? For example, given a 5-d shape (2, 3, 4, 5, 6) with a reduction on axis=(1, 2), the TBlob will first be reshaped into (2, 12, 30) and then reduced on axis=1. In this case, do you need a kernel generated for 3-D shapes?

Contributor Author (hzfan):

I think we can pad the shape after the dimension collapse. In this case, the TBlob will be reshaped into (2, 12, 30, 1, 1) and then reduced on axis=[1, 3].

Contributor:

I see. I am in favor of the approach with fewer kernels generated. We can revisit the performance concern if it turns out to be an issue.

Contributor Author (hzfan):

I pushed a new version, where the inputs and outputs are padded to 5 dims.
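
To make the final scheme concrete, here is a hypothetical NumPy-level rendering of the worked example from this thread: collapse consecutive reduced dimensions, then pad to 5 dims with trailing size-1 axes so a single family of 5-D kernels (reducing over the alternating axes) covers every case.

import numpy as np

ograd = np.ones((2, 3, 4, 5, 6))
# Reducing over axes (1, 2): collapse them into one dimension ...
collapsed = ograd.reshape(2, 12, 30)
# ... then pad to 5 dims so the same 5-D kernel handles this case too,
# reducing over the alternating axes [1, 3] (axis 3 is a size-1 no-op here).
padded = collapsed.reshape(2, 12, 30, 1, 1)
igrad = padded.sum(axis=(1, 3), keepdims=True)
print(igrad.shape)  # (2, 1, 30, 1, 1)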

@yzhliu merged commit 9023256 into apache:master on Aug 22, 2019