Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix as_strided for inputs smaller than the arguments specification. #5914

Merged
merged 7 commits into from
Nov 28, 2023

Conversation

ysiraichi
Copy link
Collaborator

Fix: #5719

This PR introduces a base_ attribute for XLATensor. It keeps track of the tensor whose storage would be aliased by the outer tensor due to a view operation.

a = torch.rand(5, device=xm.xla_device())  # base_ is undefined
b = a + a  # base_ is undefined
c = b[2:]  # base_ is b
d = c.as_strided((5,), (1,), 0)  # uses b (the base_ tensor), instead of c, as input
                                 # base_ is b

@JackCaoG
Copy link
Collaborator

lol do you mind resolve the conflict?

@ysiraichi
Copy link
Collaborator Author

@JackCaoG I think this is ready for another round of reviews. Could you take a look at it?

const at::Tensor& GetRootBase(const at::Tensor& tensor);
// Sets the base tensor of a given XLATensor. Convenient function
// to be used when returning tensors.
XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we ever expect base to be on non-xla device? If not can we add an explict check?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I don't think so, since we got to a XLA dispatched kernel. Will add the check.

torch_xla/csrc/tensor.h Outdated Show resolved Hide resolved
@ysiraichi ysiraichi merged commit 4ac255a into pytorch:master Nov 28, 2023
18 checks passed
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
…pytorch#5914)

* Add test.

* Create `base_` tensor for views.

* Use base tensor in `as_strided` operation.

* Set base tensor of `as_strided`.

* Fix lint errors.

* Fix for disabled functionalization.

* Address review.
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Dec 1, 2023
…pytorch#5914)

* Add test.

* Create `base_` tensor for views.

* Use base tensor in `as_strided` operation.

* Set base tensor of `as_strided`.

* Fix lint errors.

* Fix for disabled functionalization.

* Address review.
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
…pytorch#5914)

* Add test.

* Create `base_` tensor for views.

* Use base tensor in `as_strided` operation.

* Set base tensor of `as_strided`.

* Fix lint errors.

* Fix for disabled functionalization.

* Address review.
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
…#5914)

* Add test.

* Create `base_` tensor for views.

* Use base tensor in `as_strided` operation.

* Set base tensor of `as_strided`.

* Fix lint errors.

* Fix for disabled functionalization.

* Address review.
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
…#5914)

* Add test.

* Create `base_` tensor for views.

* Use base tensor in `as_strided` operation.

* Set base tensor of `as_strided`.

* Fix lint errors.

* Fix for disabled functionalization.

* Address review.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Lowering as_strided errors for input tensors smaller than size-stride specs.
2 participants