Add RWKV-4 #22797
# TODO: maybe jit, otherwise move inside forward
def extract_key_value(self, hidden, state=None):
    # Mix hidden with the previous timestep to produce key, value, receptance
    shifted = self.time_shift(hidden) if state is None or hidden.size(1) != 1 else state[1][:, :, self.layer_id]
It seems this shift mistakenly drops the previous hidden in state when provided with a sequence of length larger than 1. In the case where state is not None and hidden.size(1) != 1, it should cut the last token and prepend the token from state.
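To make the suggested behaviour concrete, here is a minimal list-based sketch of a time shift that takes the previous chunk's last token into account. It is not the PR's code: `time_shift` and `prev_last` are hypothetical names, a sequence is modelled as a plain list of token vectors, and the zero vector at the start of a fresh sequence is an assumption about how the shift pads.

```python
def time_shift(hidden, prev_last=None):
    """Shift one sequence a single step to the right.

    hidden:    list of per-token vectors (length T) for one sequence.
    prev_last: last token vector of the previous chunk, or None when the
               sequence starts from scratch (assumed to pad with zeros).
    """
    width = len(hidden[0])
    first = prev_last if prev_last is not None else [0.0] * width
    # Drop the last token and prepend the carried-over (or zero) token.
    return [list(first)] + [list(token) for token in hidden[:-1]]


chunk = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
fresh = time_shift(chunk)                  # no state: zeros are prepended
resumed = time_shift(chunk, [9.0, 9.0])    # state: previous token is prepended
single = time_shift([[5.0, 5.0]], [4.0, 4.0])  # T == 1 reduces to "use the state"
```

With a sequence of length 1 the result is just the carried-over token, which is the only case the line under review handles; the longer-sequence case is the one the comment says is dropped.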
Indeed, thanks for the pointer!
Should be fixed now!
- fix common tests
- fix configuration default values
- add CI test for checking state computation
- fix some CI tests

- fix config docstring
- fix failing tests

- add output_attention / output_hidden_states
- override test_initialization
- fix failing CIs

- fix sharded case
- add new arguments
IMO the model is in nice shape! I would love to have a round of review before I transfer the weights to the proper organization!
@@ -93,15 +93,20 @@
 class ConfigTester(object):
-    def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
Not sure why this change is added here?
Because, from my understanding, RWKV has no notion of attention heads. This default test always expects num_attention_heads to exist, so I decided to make it slightly modular by accepting custom common_properties. Since we might have more models like this in the future, I thought it would be a good idea.
Happy to revert it, or maybe override that test instead, if you think that's better.
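For readers who have not seen the test helper, the idea can be sketched roughly as follows. This is a simplified stand-in, not the real `ConfigTester`: `SketchConfigTester`, `HeadlessConfig`, the default property tuple and `missing_common_properties` are all hypothetical names chosen for illustration.

```python
class SketchConfigTester:
    """Simplified stand-in for a config tester with overridable common properties."""

    # Assumed defaults for this sketch only.
    DEFAULT_COMMON_PROPERTIES = ("hidden_size", "num_attention_heads", "num_hidden_layers")

    def __init__(self, config_class, common_properties=None, **kwargs):
        self.config_class = config_class
        # Fall back to the defaults only when the caller passes nothing.
        if common_properties is None:
            common_properties = self.DEFAULT_COMMON_PROPERTIES
        self.common_properties = tuple(common_properties)
        self.inputs = kwargs

    def missing_common_properties(self):
        """Return the expected attributes that the config does not define."""
        config = self.config_class(**self.inputs)
        return [name for name in self.common_properties if not hasattr(config, name)]


class HeadlessConfig:
    """Hypothetical config for a model without attention heads."""

    def __init__(self, hidden_size=32, num_hidden_layers=2):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers


default_missing = SketchConfigTester(HeadlessConfig).missing_common_properties()
custom_missing = SketchConfigTester(
    HeadlessConfig, common_properties=["hidden_size", "num_hidden_layers"]
).missing_common_properties()
```

The default check flags the missing attribute, while the per-model override passes without touching the shared defaults, which is the trade-off being discussed against overriding the test in the model's own test file.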
Thanks for the explanation!
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
I added a "mock" attention mask here so that pipeline won't fail, complaining that the attention mask is output by the tokenizer but not used by the model. As we want to add attention mask support in the future anyway, I thought this was the simplest solution for now. To reproduce:
from transformers import pipeline
model_id = "ybelkada/rwkv-4-169m-pile"
prompt = "Hello"
pipe = pipeline("text-generation", model=model_id)
print(pipe(prompt, max_length=10))
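The pattern itself can be illustrated without the library. The sketch below only mimics the symptom with a plain `TypeError`; the actual check inside the pipeline is not reproduced here, and `tokenize`, `forward` and `strict_forward` are hypothetical stand-ins.

```python
def tokenize(text):
    """Hypothetical tokenizer that returns the usual two keys."""
    ids = [ord(ch) for ch in text]
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


def strict_forward(input_ids=None, state=None):
    """Signature without attention_mask: rejects the tokenizer's output."""
    return {"logits": [float(i) for i in input_ids], "state": state}


def forward(input_ids=None, attention_mask=None, state=None):
    """Accepts attention_mask for interface compatibility but ignores it."""
    del attention_mask  # intentionally unused in this sketch
    return {"logits": [float(i) for i in input_ids], "state": state}


encoded = tokenize("Hello")
outputs = forward(**encoded)  # works: the extra key is accepted and dropped

try:
    strict_forward(**encoded)
    rejected = False
except TypeError:
    rejected = True  # the unexpected keyword argument is refused
```

Accepting and ignoring the argument keeps callers working today and leaves room to give it real semantics later without another signature change.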
Thanks @sgugger for taking the lead on this! Learned a lot 🔥
@younesbelkada In README.md, the name should be "Bo Peng" (Peng is the surname) instead of "Peng Bo" :)
    forward_func = rwkv_cuda_kernel.forward_with_state_bf16
else:
    forward_func = rwkv_cuda_kernel.forward_with_state
# TODO: update CUDA kernel so it uses the initial state provided here.
It seems this TODO has already been done?
@staticmethod
# g stands for grad
def backward(ctx, g_output):
Any plan on supporting gradients on states? It would make chaining wkvs in training possible, getting rid of the seqlen limitation. It would also match the _with_state variant of WKV forward.
I think this makes sense. However, I would advocate doing that in a follow-up PR, to at least unlock the model addition for users who already want to use the model with transformers.
May I help on that later?
Sure! You are more than welcome to help us with that.
time_decay, time_first, key, value, output = ctx.saved_tensors
# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty_like(
    time_decay,
If I read it right, time_decay/time_first is of shape (C,) while the CUDA kernel requires gw and gu of shape (B, C) (see wkv_cuda.cu:140-141). This may cause a buffer overflow in VRAM as well as wrong results. I haven't set up the environment to test it, but I suspect the current g_time_decay/g_time_first will unexpectedly become scalars after the summation in lines 192-193, which would confirm my guess.
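The shape argument can be checked with a small pure-Python model of the write pattern. This is only an illustration under the assumption stated in the comment, namely that the kernel writes one gradient per (batch, channel) pair; `kernel_like_backward` and `reduce_over_batch` are hypothetical names, and the real kernel would not raise on an out-of-range index.

```python
def kernel_like_backward(batch_size, channels, g_buffer):
    """Hypothetical stand-in for a kernel writing one value per (batch, channel)."""
    for b in range(batch_size):
        for c in range(channels):
            # A real CUDA kernel would not bounds-check here; Python raises
            # IndexError instead of silently writing past the buffer.
            g_buffer[b * channels + c] = 1.0


def reduce_over_batch(g_buffer, batch_size, channels):
    """Sum a flat (B, C) buffer over the batch dimension, giving (C,) values."""
    return [
        sum(g_buffer[b * channels + c] for b in range(batch_size))
        for c in range(channels)
    ]


B, C = 4, 3

# Buffer sized (B, C): every write lands in bounds, then reduce to (C,).
full_buffer = [0.0] * (B * C)
kernel_like_backward(B, C, full_buffer)
grad = reduce_over_batch(full_buffer, B, C)

# Buffer sized (C,), as empty_like(time_decay) would give: writes overflow.
small_buffer = [0.0] * C
try:
    kernel_like_backward(B, C, small_buffer)
    overflowed = False
except IndexError:
    overflowed = True
```

Summing the correctly sized buffer over the batch dimension leaves one value per channel, whereas summing an already (C,)-shaped buffer over its first dimension would collapse it to a single scalar, which is the symptom predicted above.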
hi @sgugger, thanks A TON for this merge! I am trying to train a new model of type and facing the following error:
From what I can see, the backward function of RwkvLinearAttentionBackward does not mention a g_state - should gradients be computed for the state, I guess not? Any pointers as to how I can resolve this will be very much appreciated!
I managed to get the code to run with some changes to the forward() and backward() functions:

class RwkvLinearAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
        batch_size, seq_len, hidden_size = key.size()
        if seq_len > rwkv_cuda_kernel.max_seq_length:
            raise ValueError(
                f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
                f"{rwkv_cuda_kernel.max_seq_length} with this model."
            )
        if batch_size * hidden_size % min(hidden_size, 32) != 0:
            raise ValueError(
                f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
                f"multiple of {min(hidden_size, 32)}."
            )
        ctx.input_dtype = key.dtype
        if (
            time_decay.device.type != "cuda"
            or time_first.device.type != "cuda"
            or key.device.type != "cuda"
            or value.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")
        time_decay = -torch.exp(time_decay.float().contiguous())
        if key.dtype == torch.float16:
            time_first = time_first.float()
            key = key.float()
            value = value.float()
        time_first = time_first.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        # The CUDA kernel will fill this tensor.
        output = torch.empty_like(key, memory_format=torch.contiguous_format)
        if return_state or state is not None:
            if state is None:
                state = torch.zeros(
                    batch_size,
                    hidden_size,
                    3,
                    dtype=torch.float32,
                    device=key.device,
                    memory_format=torch.contiguous_format,
                )
                state[:, :, 2] -= 1e38
            else:
                state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
            if key.dtype == torch.bfloat16:
                forward_func = rwkv_cuda_kernel.forward_with_state_bf16
            else:
                forward_func = rwkv_cuda_kernel.forward_with_state
            forward_func(time_decay, time_first.to(key.dtype), key, value, output, state)
        else:
            forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
            forward_func(time_decay, time_first.to(key.dtype), key, value, output)
        ctx.save_for_backward(time_decay, time_first, key, value, output)
        if state is not None:
            state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]
        return output.to(ctx.input_dtype), state

    @staticmethod
    def backward(ctx, g_output, g_state):
        # g_state is accepted because forward returns two outputs; it is not used.
        input_dtype = ctx.input_dtype
        time_decay, time_first, key, value, output = ctx.saved_tensors
        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty_like(
            time_decay,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_time_first = torch.empty_like(
            time_first,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        g_value = torch.empty_like(value, memory_format=torch.contiguous_format)
        if input_dtype == torch.float16:
            g_output = g_output.float()
        backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
        backward_func(
            time_decay,
            time_first.to(key.dtype),
            key,
            value,
            output,
            g_output.contiguous(),
            g_time_decay,
            g_time_first,
            g_key,
            g_value,
        )
        # g_time_decay = torch.sum(g_time_decay, dim=0)
        # g_time_first = torch.sum(g_time_first, dim=0)
        # No gradients are returned for state and return_state.
        return (
            g_time_decay.to(input_dtype),
            g_time_first.to(input_dtype),
            g_key.to(input_dtype),
            g_value.to(input_dtype),
            None,
            None,
        )

One problem I run into now is that, although I'm trying to train a fairly small model (12 layers, hidden size 256, context size 64), I can only train with a very small batch size (16) on a 40GB A100 card. For comparison, a RoBERTa model of similar size allows a batch size of 256. This seems counterintuitive to me, but I might be wrong.

Another issue I observed is instability: in some cases, within the first 3 steps of training, the loss goes from something normal like 10 to 90543067814198.3 and then to 0.0. This seems to happen more often when bf16 training is disabled, and at higher batch sizes when bf16 training is enabled.
@YovaKem Would you mind trying to change this

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty_like(
    time_decay,
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
)
g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)

to

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty(
    key.shape[0], key.shape[2],
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
)
g_time_first = torch.empty(k.shape[0], k.shape[2], memory_format=torch.contiguous_format)

I suspect there's an overflow in the current code, as mentioned above in the review comment, but I haven't tested it yet. The binary distribution on PyPI does not include the CUDA kernels XD

Also, the gradient of the state should be computed, but the current kernel is not doing it. Once I have set up the environment, I'll open a PR.
Thanks @Blealtan! I guess you meant

# The CUDA kernel will fill those tensors.
g_time_decay = torch.empty(
    key.shape[0], key.shape[2],
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
).to(key.device)
g_time_first = torch.empty(
    key.shape[0], key.shape[2],
    memory_format=torch.contiguous_format,
    dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
).to(key.device)

This seems to solve both the OOM issue and the instability!

One question regarding your comment on state gradients - I now saw this
In what cases is the _with_state variant used? As far as I can see, the model I'm training is not passing states at all during the forward step. Is that something that only becomes relevant at inference time, when the model is used like an RNN?
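As background for the question, here is a single-channel, pure-Python sketch of a WKV recurrence that carries a three-part state (numerator, denominator, running maximum exponent), with the maximum initialised to -1e38 like the third state slot in the code above. It is an illustration of the RNN-style use only, not the CUDA kernel; `wkv_recurrent` and its argument names are hypothetical, and `decay` is assumed to be the already negated decay (the value after `-exp(time_decay)`).

```python
import math


def wkv_recurrent(decay, first, keys, values, state=None):
    """Single-channel WKV recurrence with a numerically stabilised state.

    decay: negative per-step log decay; first: bonus for the current token.
    state: (num, den, max_exp) carried over from a previous chunk, or None.
    """
    num, den, max_exp = state if state is not None else (0.0, 0.0, -1e38)
    outputs = []
    for k, v in zip(keys, values):
        # Output: decayed past plus the current token weighted by `first`.
        m = max(max_exp, k + first)
        e1, e2 = math.exp(max_exp - m), math.exp(k + first - m)
        outputs.append((e1 * num + e2 * v) / (e1 * den + e2))
        # State update: apply one step of decay, then fold in the current token.
        m = max(max_exp + decay, k)
        e1, e2 = math.exp(max_exp + decay - m), math.exp(k - m)
        num, den, max_exp = e1 * num + e2 * v, e1 * den + e2, m
    return outputs, (num, den, max_exp)


decay, first = -0.5, 0.3
keys = [0.1, -0.3, 0.7, 0.2]
values = [1.0, 2.0, -1.0, 0.5]

# Whole sequence at once, no state passed in.
full, _ = wkv_recurrent(decay, first, keys, values)

# Same sequence in two chunks, handing the state from one call to the next.
head, carried = wkv_recurrent(decay, first, keys[:2], values[:2])
tail, _ = wkv_recurrent(decay, first, keys[2:], values[2:], state=carried)
chunked = head + tail
```

Processing the sequence in chunks while handing the state along gives the same outputs as one pass, which is what token-by-token generation relies on; doing the same during training would additionally need gradients to flow through the state, as discussed earlier in the thread.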
Hey @sgugger, how did you prepare the models? Could you point us to how to convert the original .pth or .safetensors models to your format? Thanks! PS
@lambdaofgod The logic used to convert the RWKV checkpoints from BlinkDL to HF format can be found in the conversion script.
@YovaKem AFAIK,
I have no idea why the CUDA kernels all disappeared from the package on PyPI (it's not just RWKV, but all models using custom kernels). Will investigate later today and post a patch release when I find a solution.
Normally custom kernels should be included in 4.29.2, sorry for the inconvenience. We added stronger checks to make sure they don't disappear again in a future release.
Hi, can I ask a simple question about the RWKV kernel? The RWKV model without the customized kernel uses a transformers/src/transformers/models/rwkv/modeling_rwkv.py Lines 223 to 241 in 3658488
I am not familiar with CUDA kernels, so I am not sure whether the customized CUDA kernel still computes sequentially and delivers a faster
Putting this here so it doesn't get lost. I am trying to run microsoft guidance (https://github.com/microsoft/guidance) on RWKV through transformers and I am getting an error
which can be reproduced here: https://gist.github.com/fullstackwebdev/a6523374e6687825fcb92ca74048c12b
@fullstackwebdev
* First draft of RWKV-4
* Add support for generate
* Style post-rebase
* Properly use state
* Write doc
* Fix doc
* More math
* Add model to README, dummies and clean config
* Fix init
* multiple fixes:
  - fix common tests
  - fix configuration default values
  - add CI test for checking state computation
  - fix some CI tests
* correct tokenizer
* some tweaks
  - fix config docstring
  - fix failing tests
* fix CI tests
  - add output_attention / output_hidden_states
  - override test_initialization
  - fix failing CIs
* fix conversion script
  - fix sharded case
  - add new arguments
* add slow tests + more fixes on conversion script
* add another test
* final fixes
* change single name variable
* add mock attention mask for pipeline to work
* correct eos token id
* fix nits
* add checkpoints
* Apply suggestions from code review
  Co-authored-by: amyeroberts <[email protected]>
* add `tie_word_embeddings` in docstring
* change tensor name
* fix final nits
* Trigger CI
---------
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
This PR is a draft and while there is a working implementation of the model, there is still a lot to do :-)
This PR adds the RWKV model from BlinkDL/RWKV-LM, which is an RNN-like Transformer: it has an attention layer and a feed-forward layer, but the attention is linear and can be expressed recurrently (more details coming in the doc page of the model).
Here is a code snippet to play with the model:
To use the chat models (called raven):
Fixes #20737
Fixes #17230
TODO:
cc @ArthurZucker and @younesbelkada