[Runtime] Enable set_input_zero_copy in GraphRuntime #3416
Conversation
Can you add test cases for `set_input_zero_copy`?
Will do.
@hlu1 Added tests.
src/runtime/graph/graph_runtime.cc (Outdated)
```cpp
for (size_t i = 0; i < data_entry_.size(); ++i) {
  int storage_id = attrs_.storage_id[i];
  CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
  data_entry_[i] =
      storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
  const DLTensor* tmp = data_entry_[i].operator->();
  data_alignment_[i] = GetDataAlignment(*tmp);
  dltensor_entry_shapes_[i].resize(tmp->ndim);
```
The shape info is stored in `attrs_.shape`. There is no need to save it into `dltensor_entry_shapes_`.
src/runtime/graph/graph_runtime.cc (Outdated)
```cpp
// check the consistency of input shape
CHECK_EQ(data_alignment_[eid], GetDataAlignment(*data_ref));
CHECK(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment == 0);
```
Why is this necessary?
TVM assumes 64-byte alignment of memory addresses. Failing to guarantee that causes subtle core dumps, for example `vmovaps` on unaligned memory. When we set up an external input, we need to guard against this, because we don't know how the memory was allocated.
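As a minimal sketch, a guard along these lines could reject misaligned external buffers before they ever reach generated code (the helper name `IsAligned` is hypothetical; the 64-byte constant mirrors TVM's `kAllocAlignment`):

```cpp
#include <cstdint>
#include <dlpack/dlpack.h>

// Hypothetical helper: true iff the tensor's data pointer meets the
// runtime's minimum allocation alignment (64 bytes in TVM).
inline bool IsAligned(const DLTensor* t, size_t alignment) {
  return reinterpret_cast<uintptr_t>(t->data) % alignment == 0;
}
```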
src/runtime/graph/graph_runtime.cc (Outdated)
```cpp
    CHECK_EQ(shape[i], data_ref->shape[i]);
  }
} else {
  int64_t acc_prev =
```
Do we hit this case? Shouldn't the shapes match exactly?
Yeah, we don't need this any more.
```cpp
 * \param index The input index.
 * \param data_ref The input data that is referred.
 */
void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
```
Please check that the `device_type` and `device_id` match as well, for the heterogeneous case.
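A minimal sketch of what that check could look like inside `SetInputZeroCopy` (assuming the old entry is available as `old_t`, as in the snippet further below; not necessarily the exact code that landed):

```cpp
// Reject tensors that live on a different device than the entry they
// replace, so heterogeneous (multi-device) graphs stay consistent.
CHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type);
CHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id);
```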
```diff
@@ -206,6 +244,12 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   CHECK_EQ(data_entry_[eid].use_count(), 1);
   data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
   CHECK_GT(data_entry_[eid].use_count(), 1);
   const DLTensor* tmp = data_entry_[eid].operator->();
```
I'm not sure the `ShareParams()` function needs to be updated here. If the params are set by sharing with another graph runtime, they do not need to be set again. The shape and alignment should be the same as before the sharing.
I ran into test failures when I didn't do this.
```cpp
}

// Update the data pointer for each argument of each op
for (auto& op_arg : op_args_) {
```
I'm thinking we could change the API to accept an array of DLTensors, update the entries in `data_entries_` for all of them, and then call `SetupOpExecs()`. That way, you don't need to save the `input_entry_ids` in `op_args_`, and the code can be much cleaner.
I had an implementation like this, but it turned out to be more difficult:

- To update `data_entries_`, you need to create a DLManagedTensor, which is one more small allocation.
- Rerunning `SetupOpExecs()` seems a bit heavy.
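For reference, a hypothetical sketch of the batch API being discussed, illustrating both costs; the name `SetInputsZeroCopy` and its signature are illustrative, not code from this PR:

```cpp
// Hypothetical batch variant: rewrite all affected data entries,
// then rebuild every op executor in one pass.
void GraphRuntime::SetInputsZeroCopy(const std::vector<int>& indices,
                                     const std::vector<DLTensor*>& tensors) {
  CHECK_EQ(indices.size(), tensors.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    uint32_t eid = this->entry_id(input_nodes_[indices[i]], 0);
    // Viewing an external DLTensor as an NDArray requires wrapping it
    // in a DLManagedTensor -- the extra small allocation noted above.
    auto* wrapper = new DLManagedTensor();
    wrapper->dl_tensor = *tensors[i];
    wrapper->manager_ctx = nullptr;
    wrapper->deleter = [](DLManagedTensor* self) { delete self; };
    data_entry_[eid] = NDArray::FromDLPack(wrapper);
  }
  // Heavyweight: re-creates the packed-function closures for all ops.
  this->SetupOpExecs();
}
```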
@ZihengJiang, @eqy, @kevinthesun, @icemelon9, can you take a look?
Looks good to me.
include/tvm/runtime/ndarray.h (Outdated)
```diff
@@ -33,6 +33,8 @@
 namespace tvm {
 namespace runtime {

 size_t GetDataAlignment(const DLTensor& arr);
```
Previously this API was private, so it was written in a not very thoughtful way. I am debating whether we should include it as a public API. Perhaps just have two inlined versions of this function internally for now, since the behavior was mainly to return `kTempAllocaAlignment`. If it is a public API, we need to document it properly.
Yeah, I can make them local functions.
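A minimal sketch of such a file-local helper, modeled on the alignment rule discussed above (assuming `kAllocAlignment` is the runtime's minimum allocation alignment):

```cpp
namespace {
// File-local helper: alignment implied by the dtype, clamped below
// by the runtime's minimum allocation alignment (64 bytes).
size_t GetDataAlignment(const DLTensor& arr) {
  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
  if (align < tvm::runtime::kAllocAlignment) {
    return tvm::runtime::kAllocAlignment;
  }
  return align;
}
}  // namespace
```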
src/runtime/graph/graph_runtime.cc (Outdated)
```cpp
const DLTensor* old_t = data_entry_[eid].operator->();

// check the consistency of input
CHECK_EQ(data_alignment_[eid], GetDataAlignment(*data_ref));
```
I am wondering if we need to introduce `data_alignment_`. It looks like we can get the alignment from `data_entry_[eid]` as well, right?
Yeah, I originally just wanted to avoid computing this repeatedly. What do you think? I don't have a strong opinion about this.
Yeah, I think we probably should not introduce extra members when not really necessary. The compute is cheap and used by the other field as well.
It's an extra dozen bytes of memory, but the trade-off is that you avoid doing the compute for each inference (we are talking about millions of them). Sounds reasonable?
@tqchen ping
* Enable set_input_zero_copy in GraphRuntime
* Fix LoadParams
* Fix
* lint
* Fix remote context issue
* Fix
* Remove LOG
* Remove unused variables
* Add tests
* works
* More test scenarios
* make it simpler
* Remove unnecessary changes
* Address comments
* More comments
* Address comments
* Fix build
Does it support GPU zero copy? We tried zero copy with a GPU context, and it generates an error.
When integrating with other frameworks such as PyTorch, it is desirable to avoid unnecessary copies of activations and weights when hooking up the TVM runtime.
cc: @ajtulloch @hlu1
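For illustration, a hedged sketch of calling the new entry point from C++ (assuming `mod` is a created graph runtime module and `input` is an external, 64-byte-aligned `DLTensor*`; the packed-function name matches the PR title, and the input name "data" is an assumption):

```cpp
// Point the graph input at an externally owned buffer (no copy),
// then run inference. The buffer must stay alive across run().
tvm::runtime::PackedFunc set_input_zero_copy =
    mod.GetFunction("set_input_zero_copy");
set_input_zero_copy("data", input);  // "data": assumed input name
tvm::runtime::PackedFunc run = mod.GetFunction("run");
run();
```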