Are the values properly updated (`self.key_cache` and `self.value_cache`), since you are doing a copy? Also, do you see a memory increase from this modification?
I'm afraid I have an issue with how to check this explicitly, since adding `print()` here does not work. Can you suggest what and how I should check? However, the documentation on `clone` says: "This function is differentiable, so gradients will flow back from the result of this operation to input. To create a tensor without an autograd relationship to input see detach()." So I think updates to the cloned tensor should propagate back, but I honestly don't know how to check this explicitly.

I did not check, but my expectation is that there will be a memory increase, since what `clone()` does is create a new tensor. Also, I noted this issue on the PyTorch side: pytorch/pytorch#131679.

I think the better fix here will actually be to if-check whether tensors are on the wrong device and, if they are, call `to(copy=True)`. But this assumes a fix for the above PyTorch issue. I also see this change (merged) on the PyTorch side, pytorch/pytorch#132529, where they seem to work around the issue with non-mutable tensors after calling `.to()` for some other particular case (search for issue 131679 in the change). If I read it correctly, they are doing the same thing: calling `clone()` and doing the operation on a cloned tensor.

@bdhirsh, @SherlockNoMad, @albanD: may I ask your help to weigh in on this issue?
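P.S. In eager mode (outside of export), I would expect a check for both questions to look something like this sketch (hypothetical shapes; `data_ptr()` compares storage, and on cuda `torch.cuda.memory_allocated()` before/after `clone()` would show the increase):

```python
import torch

key_cache = torch.zeros(1, 4, 16, 8)
ptr_before = key_cache.data_ptr()

# clone() allocates fresh storage, so a memory increase is expected.
k_out = key_cache.clone()
assert k_out.data_ptr() != ptr_before

# In-place updates on the clone do NOT show up in the original tensor;
# the values are only "properly updated" if the clone is written back.
k_out[..., 0, :] = 1.0
assert key_cache.sum() == 0  # original untouched
key_cache = k_out            # explicit write-back
```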
ah just commented here pytorch/pytorch#131679 (comment) - Tugsuu is (tentatively) going to look into making that error less intrusive so this case should "just work" without a clone. It'll probably take some time though
@bdhirsh: thank you so much for looking into this. In the meantime, we are considering 2 workarounds on the HF side:

1. Call `clone()` on the cache tensors and perform the in-place update on the clones (the current change in this PR).
2. If-check whether the tensors are on the wrong device and, if they are, call `to(copy=True)`.

@bdhirsh, can you guide us on whether the 1st or 2nd WA is preferable, or would you just recommend waiting a couple of days for the fix on the PyTorch side? Roughly, the two WAs would look like the sketch below.
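(Names like `key_cache`, `key_states`, and `cache_position` follow this thread; this is a sketch of the cache-update path, not the exact PR diff.)

```python
import torch

def update_key_cache(key_cache, layer_idx, key_states, cache_position,
                     use_clone_wa=True):
    """Sketch of the two WAs for updating a static kv cache under export."""
    if use_clone_wa:
        # WA 1: clone, update the clone in place, write it back.
        # clone() allocates new storage, so some memory increase is expected.
        k_out = key_cache[layer_idx].clone()
        k_out.index_copy_(2, cache_position, key_states)
        key_cache[layer_idx] = k_out
    else:
        # WA 2: only copy when the cache tensor sits on the wrong device,
        # then update it in place. This assumes the PyTorch issue with
        # mutating .to()-produced tensors under export is fixed.
        if key_cache[layer_idx].device != key_states.device:
            key_cache[layer_idx] = key_cache[layer_idx].to(
                key_states.device, copy=True
            )
        key_cache[layer_idx].index_copy_(2, cache_position, key_states)
    return key_cache[layer_idx]
```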
I'm actually a bit confused by the code. For example, in https://github.com/huggingface/transformers/pull/33287/files, I don't think that code will do anything: it won't actually change the device of `self.key_cache` (it will remain on cuda, even if `key_states` lives on cpu). `self.key_cache[0]` is taking a view/slice off of `self.key_cache`, and you can't change the device of "part" of a tensor.
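For example (illustrative, not the PR's code; assumes a cuda device is available):

```python
import torch

key_cache = torch.zeros(2, 4, 8, device="cuda")

# This just copies values back into the existing cuda storage; nothing
# moves. key_cache[0] is a view into key_cache, and a view cannot live
# on a different device than its base tensor.
key_cache[0] = key_cache[0].to("cpu")

print(key_cache.device)     # cuda:0
print(key_cache[0].device)  # cuda:0
```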
I guess my meta question is: in general, don't you know the device of your kv cache ahead of time? Can you ensure that your model (and its kv cache) are initialized on the right device to begin with, so that when you export, you don't need to do any device conversions? You probably don't want any device conversions in the runtime path (inside of your exported program) anyway, since they will hurt perf.
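E.g., something like this (a sketch with illustrative shapes and names):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_layers, batch, heads, max_len, head_dim = 2, 1, 4, 128, 64

# Allocate the kv cache on the target device once, up front, so the
# exported program's update path is a pure in-place index_copy_ with
# no device conversions at runtime.
key_cache = [torch.zeros(batch, heads, max_len, head_dim, device=device)
             for _ in range(num_layers)]
value_cache = [torch.zeros(batch, heads, max_len, head_dim, device=device)
               for _ in range(num_layers)]
```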
#33303 🤗
cool :)