common : fix state save in common_prompt_batch_decode #23468
I tested it on my toy Llama-3.1 example from the bug report and on a longer prompt with DeepSeek V3.2, and it looks good: the model output during the replay of the last prompt token now matches the model output from when the memory state was written. I see there's one more problem though, where state is not saved at all if I have a long prompt and there's no
@danbev I printed some debug messages during processing of a long prompt: It looks like
This commit addresses a bug in common_prompt_batch_decode that affects the session state store/restore in completion.cpp and save-load-state.cpp.
The motivation is that the code currently saves n-1 tokens in both session_tokens and the KV cache. When the session tokens are later loaded and the prompt matches, the last saved token (n-1) is replayed at the next position, so the same token is decoded a second time at the wrong position.
The fix is to store all n tokens in session_tokens while the memory state still reflects only the n-1 processed tokens, because the state is saved before the last token is decoded in common_prompt_batch_decode.
I ran both completion.cpp and save-load-state.cpp with a transformer, a recurrent, and a hybrid model.
Resolves: ggml-org#23400
Co-authored-by: fairydreaming <166155368+fairydreaming@users.noreply.github.com>
b698524 to 411c926
@fairydreaming Thanks for reporting and testing this!
I confirm that it works now. First in the writer I see: then in the reader:
@fairydreaming Thanks!
Overview
This commit addresses a bug in common_prompt_batch_decode that affects the session state store/restore in completion.cpp and save-load-state.cpp.
Additional information
The motivation is that the code currently saves n-1 tokens in both session_tokens and the KV cache. When the session tokens are later loaded and the prompt matches, the last saved token (n-1) is replayed at the next position, so the same token is decoded a second time at the wrong position.
The fix is to store all n tokens in session_tokens while the memory state still reflects only the n-1 processed tokens, because the state is saved before the last token is decoded in common_prompt_batch_decode.
I ran both completion.cpp and save-load-state.cpp with a transformer, a recurrent, and a hybrid model.
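To make the off-by-one concrete, here is a minimal toy sketch of the behavior described above. It does not use the llama.cpp API: `toy_session`, `toy_save`, and `toy_replay` are hypothetical names invented for illustration, and the "memory" is reduced to a count of processed positions. It assumes, as described, that the state is saved before the last prompt token is decoded and that the last saved token is replayed at the next free position on load.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using token = int;

// Hypothetical stand-in for a saved session (not a llama.cpp type):
// the stored token list plus the number of positions already held in
// the memory state (KV cache or recurrent state).
struct toy_session {
    std::vector<token> tokens;
    int n_past;
};

// The save happens before the last prompt token is decoded, so the
// memory always covers n - 1 positions. store_all_tokens selects
// between the old behavior (n - 1 tokens stored) and the fixed
// behavior (all n tokens stored).
static toy_session toy_save(const std::vector<token> & prompt, bool store_all_tokens) {
    assert(!prompt.empty());
    const int n = (int) prompt.size();
    toy_session s;
    s.n_past = n - 1;
    s.tokens.assign(prompt.begin(), prompt.begin() + (store_all_tokens ? n : n - 1));
    return s;
}

// On load with a matching prompt, the last saved token is replayed at
// the next free position. Returns {position, token}.
static std::pair<int, token> toy_replay(const toy_session & s) {
    return { s.n_past, s.tokens.back() };
}
```

With a four-token prompt `{10, 11, 12, 13}`, `toy_replay(toy_save(prompt, false))` returns position 3 with token 12, which is the token that belongs at position 2. `toy_replay(toy_save(prompt, true))` returns position 3 with token 13, which is the token that belongs there.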
Requirements
Resolves: #23400