4 changes: 2 additions & 2 deletions src/transformers/cache_utils.py
@@ -1087,8 +1087,8 @@ def update(
         cache_position = cache_kwargs.get("cache_position")
         self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
         self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
-        k_out = self.key_cache[layer_idx]
-        v_out = self.value_cache[layer_idx]
+        k_out = self.key_cache[layer_idx].clone()
+        v_out = self.value_cache[layer_idx].clone()
Comment on lines +1090 to +1091
Member
Are the values properly updated (self.key_cache and self.value_cache) since you are doing a copy? Also, do you see a memory increase from this modification?

Contributor Author
> Are the values properly updated (self.key_cache and self.value_cache) since you are doing a copy?

I'm afraid I have an issue with how to check this explicitly, since adding print() here does not work. Can you suggest what I should check and how? However, the documentation on clone() says: "This function is differentiable, so gradients will flow back from the result of this operation to input. To create a tensor without an autograd relationship to input see detach()." So I think updates to the cloned tensor should propagate back, but I honestly don't know how to check this explicitly.
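
For what it's worth, one way to check the clone semantics outside the exported graph is a small eager-mode experiment (tensor names here are illustrative, not the actual cache code):

import torch

# Eager-mode sketch (illustrative names): clone() gives the output its own storage,
# so in-place writes to the clone do not reach the original tensor. "Differentiable"
# only means gradients flow back through the clone.
cache = torch.zeros(2, 3)
out = cache.clone()
print(out.data_ptr() == cache.data_ptr())   # False: separate storage
out.copy_(torch.ones(2, 3))                 # in-place update lands in the clone only
print(cache)                                # still all zeros: the original buffer is untouched

leaf = torch.zeros(2, 3, requires_grad=True)
leaf.clone().sum().backward()               # gradients do flow back through clone()
print(leaf.grad)                            # all ones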

> Also, do you see a memory increase from this modification?

I did not check, but my expectation is that there will be a memory increase, since what clone() does is create a new tensor.
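
A rough estimate of that increase, with an illustrative shape (not the real cache sizes): each cloned layer costs one extra buffer of the same size.

import torch

# Illustrative shape only: (batch, num_heads, max_seq_len, head_dim)
layer_cache = torch.zeros(1, 8, 4096, 128, dtype=torch.float16)
extra_bytes = layer_cache.nelement() * layer_cache.element_size()
print(f"cloning one such layer allocates ~{extra_bytes / 2**20:.1f} MiB extra")  # ~8.0 MiB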

Also, I noted this issue on the PyTorch side: pytorch/pytorch#131679.

I think the better fix here would actually be to check whether the tensors are on the wrong device and, if they are, call to(copy=True). But this assumes a fix for the above PyTorch issue. I also see this change (merged) on the PyTorch side, pytorch/pytorch#132529, where they seem to work around the issue with non-mutable tensors after calling .to() for some other particular case (search for issue 131679 in the change). If I read it correctly, they are doing the same thing: calling clone() and doing the operation on the cloned tensor.
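
Sketching what that alternative could look like inside the cache's update() (hypothetical code, not this PR's diff, and it assumes the PyTorch issue above is fixed so the moved tensor stays mutable):

# Hypothetical alternative: only materialize a copy when the cached tensor
# is actually on the wrong device.
if self.key_cache[layer_idx].device != key_states.device:
    self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device, copy=True)
    self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device, copy=True)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]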

@bdhirsh, @SherlockNoMad, @albanD : may I ask your help to weigh in on this issue?

@bdhirsh Sep 4, 2024

Ah, I just commented here: pytorch/pytorch#131679 (comment). Tugsuu is (tentatively) going to look into making that error less intrusive, so this case should "just work" without a clone. It'll probably take some time, though.

Contributor Author
@bdhirsh: thank you so much for looking into this. In the meantime, we are considering two workarounds on the HF side:

  1. This PR (WIP: Clone tensors to be able to mutate #33179)
  2. PR from @guangy10 (Unbreak torch export with static cache #33287)

@bdhirsh, can you advise whether the 1st or the 2nd workaround is preferable, or do you recommend just waiting a couple of days for the fix on the PyTorch side?


I'm actually a bit confused by the code. For example, in https://github.com/huggingface/transformers/pull/33287/files, I don't think that code will do anything:

# let's say self.key_cache is a tensor with device == 'cuda'
# and let's say key_states.device == 'cpu'
self.key_cache[0] = self.key_cache[0].to(device=key_states.device)

The above code won't actually change the device of self.key_cache (it will remain on cuda, even if key_states lives on cpu). self.key_cache[0] is taking a view/slice off of self.key_cache, and you can't change the device of "part" of a tensor.
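
A CPU-only illustration of that point (illustrative names, not the actual cache code): item assignment copies values into the tensor's existing storage and never replaces or moves that storage.

import torch

key_cache = torch.zeros(2, 4)                   # imagine this living on 'cuda'
storage_before = key_cache.data_ptr()

key_cache[0] = torch.ones(4)                    # copies values into key_cache's own storage
print(key_cache.data_ptr() == storage_before)   # True: same storage, hence same device
print(key_cache[0])                             # tensor([1., 1., 1., 1.])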


I guess my meta question is: in general, don't you know the device of your kv cache ahead of time? Can you ensure that your model (and its kv cache) are initialized on the right device to begin with, so that when you export, you don't need to do any device conversions? You probably don't want any device conversions in the runtime path (inside of your exported program) anyway, since they will hurt perf
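
For context, the suggested setup would look roughly like this (illustrative shapes and names, not the actual transformers cache API): allocate the cache buffers on the target device once, before export, so the exported graph contains no device moves.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_layers, max_batch, num_heads, max_seq, head_dim = 32, 1, 8, 128, 64

# Pre-allocate per-layer KV buffers on the target device before torch.export,
# so update() never needs a .to(...) at runtime.
key_cache = [torch.zeros(max_batch, num_heads, max_seq, head_dim, device=device)
             for _ in range(num_layers)]
value_cache = [torch.zeros_like(k) for k in key_cache]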


Contributor
> I guess my meta question is: in general, don't you know the device of your kv cache ahead of time?

#33303 🤗


cool :)


         if cache_position is None:
             k_out.copy_(key_states)