Avoid unnecessary copy in TensorSource #8849
Conversation
Force-pushed from 2b3b31f to 636a787
Force-pushed from 636a787 to e483f51
Hi @ysiraichi, just following up on the offline discussion about the copy operation. PTAL at the PR, thanks!
As a side note, we can use the DLPack machinery for doing the CUDA to XLA:CUDA transfer (that wasn't implemented at the time I worked on this). I will open an issue for this.
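For reference, a minimal sketch of what that could look like on the ATen side, assuming the PJRT client can wrap an existing device allocation (the helper name and the PJRT call are hypothetical here; only the DLPack export is standard ATen):

    #include <ATen/ATen.h>
    #include <ATen/DLConvertor.h>

    // Export a CUDA at::Tensor through DLPack so PJRT could consume the device
    // memory directly, without staging through the CPU. No data is copied here.
    void ShareCudaTensorWithPjrt(const at::Tensor& cuda_tensor) {
      DLManagedTensor* dl = at::toDLPack(cuda_tensor);
      void* device_ptr = dl->dl_tensor.data;  // raw device pointer
      // ... hand device_ptr (plus shape/dtype from dl->dl_tensor) to the PJRT
      // client, e.g. through something like CreateViewOfDeviceBuffer ...
      dl->deleter(dl);  // release once PJRT no longer needs the memory
    }

The DLManagedTensor keeps a reference to the tensor's storage until the deleter runs, so the device memory stays valid while the buffer is shared.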
// The purposes of copy are:
// 1. Ensure the memory is contiguous, which is expected by PJRT.
// 2. Move CUDA tensor to CPU since we cannot pass CUDA memory to PJRT now.
// 3. Cast data type.
// We can avoid the copy if it is not needed.
if (tensor.device() == at::kCPU && tensor.is_contiguous() &&
    tensor.dtype() == target_torch_type) {
  tensor_ = std::move(tensor);
} else {
  // TODO(ysiraichi): check, first, if tensor lives in a device that the
  // current PjRt client has access. If so, we don't need to go through the
  // CPU.
  tensor_ = std::move(tensor.to(
      at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
      /*non_blocking=*/false,
      /*copy=*/true, at::MemoryFormat::Contiguous));
}
As far as I understand it, tensor.to(...) (without the copy argument) already checks whether it should actually copy or not. So, what do you think of reverting to the old tensor.to(...) usage, but removing the copy argument, instead?
Hi @ysiraichi, I didn't find a tensor.to(...) overload without the copy arg in C++; is that only available in Python?
You are right. But I think we can just pass /* copy= */false.
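A small sketch of the behavior being relied on here (illustrative, not code from the PR): with /*copy=*/false, Tensor::to returns its input unchanged when device, dtype, and memory format already match, and only materializes a new tensor otherwise.

    #include <ATen/ATen.h>

    void ToWithoutForcedCopy() {
      at::Tensor t = at::rand({2, 3});  // CPU, float, contiguous
      at::Tensor u = t.to(
          at::TensorOptions().device(at::kCPU).dtype(at::kFloat),
          /*non_blocking=*/false,
          /*copy=*/false, at::MemoryFormat::Contiguous);
      // No conversion was needed, so `u` is the same tensor as `t`: no copy made.
      TORCH_CHECK(u.is_same(t));
    }

This lets to(...) do the "is a copy needed?" check instead of duplicating it in an explicit if-else.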
Left one small NIT. Otherwise, LGTM. If you could address the NIT before submission, that might be nice.
  tensor_ = std::move(tensor);
} else {
  TORCH_LAZY_COUNTER("AtenSourceTensorCopy", 1);
  // TODO(ysiraichi): check, first, if tensor lives in a device that the
NIT: I personally prefer to have TODOs linked to issues, and then have those assigned to people. That way, things can be more easily followed up if a contributor is no longer active.
@ysiraichi @tengyifei Actually, I took a second look at this. Skipping the copy seems to be unsafe if the underlying PJRT transfer is async.
Hmm... I don't think this is the case. That's because PyTorch tensors are ref-counted on the C++ side (unless otherwise specified). So, if we hold a C++ reference to the tensor (as TensorSource does), its storage stays alive for as long as the transfer needs it.
By the way, I still don't think you need the if-else there. I believe you can just leave the old tensor.to(...) call and pass /*copy=*/false.
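To illustrate the ref-counting point (a sketch, not code from the PR): at::Tensor is a reference-counted handle, so once it is moved into the TensorSource member, the storage stays alive even if every other handle to it is dropped before the async transfer completes. The Holder struct below is a stand-in for TensorSource.

    #include <ATen/ATen.h>

    struct Holder {
      at::Tensor tensor_;
      explicit Holder(at::Tensor t) : tensor_(std::move(t)) {}
    };

    void Demo() {
      at::Tensor t = at::rand({4});
      void* data_before = t.data_ptr();
      Holder h(std::move(t));  // the holder now owns a reference to the storage
      // The underlying storage is still alive and at the same address, so an
      // in-flight device transfer reading from it would remain valid.
      TORCH_CHECK(h.tensor_.data_ptr() == data_before);
    }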
LGTM.
Thank you so much @ysiraichi for the suggestion! Updated accordingly.
Avoid the at::Tensor copy in TensorSource when it's not necessary; the copy is only needed when the tensor is not already a contiguous CPU tensor of the target dtype (see the code comments in the diff).
The copy operation needs to be blocking, since the transfer operation depends on the copied tensor.