
Fix bug: in Gemma2, pass the "sliding_window" argument explicitly to past_key_value.update() to support the _sliding_update function. #31786

Open · wants to merge 1 commit into main

Conversation

@kkk935208447 kkk935208447 commented Jul 4, 2024

What does this PR do?

System Info
transformers 4.42.3

When the Gemma2 model generates long text that exceeds the sliding-window size (>4096 tokens), it raises a CUDA error, which appears to be caused by the _sliding_update function in HybridCache not taking effect. Reproduction and error below:

# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

input_text = "Write me a poem about Machine Learning." * 800
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=1150)
print(tokenizer.decode(outputs[0]))
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
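
As the message above suggests, a synchronous re-run gives a more precise stack trace; a minimal way to do that for the script above (a debugging aid, not part of this PR's change) is to set the environment variable before any CUDA work:

# Make CUDA kernel launches synchronous so the assert points at the real call site.
# Must be set before the model is loaded / moved to the GPU.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"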

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

@LysandreJik (Member) commented:

Could you take a look when you have a minute, @sanchit-gandhi?

@sanchit-gandhi sanchit-gandhi (Contributor) left a comment


Hey @kkk935208447! Thanks for opening this PR 🤗 I believe the bug you're experiencing is actually unrelated to the sliding-window mechanism, which, as I've explained below, is correctly updated by the current code on main. Looking at your code snippet, it looks like CUDA device errors are being thrown, which might be related to how you're moving the tensors across devices. Could you please paste the full traceback of your error? That would help massively in pinpointing where you're hitting the error. Thanks!

if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
    attention_mask = attention_mask * torch.tril(
        torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1])
    )
    if (
Contributor:

Thanks for propagating the update here! It's not essential to keep the diff file in sync with modeling_gemma2.py, since it's only used during the integration process and we don't actually import anything from this file, but it's helpful for motivating the other changes in this PR!

@@ -338,7 +339,8 @@ def forward(
        "sliding_window": self.sliding_window,
Contributor:

Note that we set sliding_window in the cache_kwargs here, and then pass the dict of cache_kwargs to the cache's .update function two lines later. Hence, the sliding window should already be handled by the current code.

@@ -338,7 +339,8 @@ def forward(
        "sliding_window": self.sliding_window,
        "cache_position": cache_position,
    }
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
Contributor:

See here how we pass the cache_kwargs to the update function of the hybrid cache, where cache_kwargs is defined as {..., "sliding_window": self.sliding_window} a few lines above. Hence, there should be no need to pass sliding_window explicitly!

Author:

#31775
There is a similar problem to mine here. I don't think it's an issue with my personal device, as I've debugged it multiple times and this phenomenon still occurs.
"The cache was created with alternating max seq lenghts of 4k and 8k, but all layers were being updated as if they were 8k, causing out-of-bounds errors and CUDA exceptions."

A follow-up comment:

Problem is with the function definition:

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    )

cache_kwargs is not actually a **kwargs argument (the ** is missing); it's just an ordinary dict. Adding a "sliding_window" key to it doesn't affect the separate sliding_window parameter of the update function, which simply keeps its default value.
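
A self-contained sketch of that point, using a toy update function (not the real HybridCache implementation): because cache_kwargs is a plain dict parameter rather than **kwargs, a "sliding_window" entry inside it never reaches the separate sliding_window parameter, which keeps its default of None:

from typing import Any, Dict, Optional


def update(cache_kwargs: Optional[Dict[str, Any]] = None,
           sliding_window: Optional[int] = None) -> Optional[int]:
    # Only the named parameter is consulted here, mirroring the issue described above.
    return sliding_window


print(update(cache_kwargs={"sliding_window": 4096}))  # None -- key inside the dict is ignored
print(update(cache_kwargs={}, sliding_window=4096))   # 4096 -- explicit argument (what this PR does)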
