From 3e0855b7ce3adc7421a847cb01960c892a623d03 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Sun, 16 Apr 2023 17:47:45 -0400 Subject: [PATCH 01/27] First draft of RWKV-4 --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/rwkv.mdx | 43 + src/transformers/__init__.py | 19 + src/transformers/kernels/rwkv/wkv_cuda.cu | 187 +++++ .../kernels/rwkv/wkv_cuda_bf16.cu | 186 +++++ src/transformers/kernels/rwkv/wkv_op.cpp | 66 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 9 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/rwkv/__init__.py | 69 ++ .../models/rwkv/configuration_rwkv.py | 110 +++ .../rwkv/convert_rwkv_checkpoint_to_hf.py | 171 ++++ src/transformers/models/rwkv/modeling_rwkv.py | 747 ++++++++++++++++++ tests/models/rwkv/__init__.py | 0 tests/models/rwkv/test_modeling_rwkv.py | 281 +++++++ 16 files changed, 1896 insertions(+) create mode 100644 docs/source/en/model_doc/rwkv.mdx create mode 100644 src/transformers/kernels/rwkv/wkv_cuda.cu create mode 100644 src/transformers/kernels/rwkv/wkv_cuda_bf16.cu create mode 100644 src/transformers/kernels/rwkv/wkv_op.cpp create mode 100644 src/transformers/models/rwkv/__init__.py create mode 100644 src/transformers/models/rwkv/configuration_rwkv.py create mode 100644 src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py create mode 100644 src/transformers/models/rwkv/modeling_rwkv.py create mode 100644 tests/models/rwkv/__init__.py create mode 100644 tests/models/rwkv/test_modeling_rwkv.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6f0c0fe3997d40..a6c8fcf5ecc83a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -397,6 +397,8 @@ title: RoCBert - local: model_doc/roformer title: RoFormer + - local: model_doc/rwkv + title: RWKV - local: model_doc/splinter title: Splinter - local: model_doc/squeezebert diff --git a/docs/source/en/model_doc/rwkv.mdx b/docs/source/en/model_doc/rwkv.mdx new file mode 100644 index 00000000000000..71c07b71810afa --- /dev/null +++ b/docs/source/en/model_doc/rwkv.mdx @@ -0,0 +1,43 @@ + + +# RWKV + +## Overview + +The RWKV model was proposed in [this repo](https://github.com/BlinkDL/RWKV-LM) + +TODO, need to write this page + +Tips: + + +This model was contributed by [sgugger](https://huggingface.co/sgugger). +The original code can be found [here](https://github.com/BlinkDL/RWKV-LM). + + +## RwkvConfig + +[[autodoc]] RwkvConfig + + +## RwkvModel + +[[autodoc]] RwkvModel + - forward + - forward_rnn + +## RwkvLMHeadModel + +[[autodoc]] RwkvForCausalLM + - forward + - forward_rnn diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cfab06515a5742..9fd080a5fb5672 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -438,6 +438,7 @@ "SamPromptEncoderConfig", "SamVisionConfig", ], + "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"], "models.segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"], "models.sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"], "models.sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"], @@ -2360,6 +2361,12 @@ "SAM_PRETRAINED_MODEL_ARCHIVE_LIST", "SamModel", "SamPreTrainedModel", + _import_structure["models.rwkv"].extend( + [ + "RWKV_PRETRAINED_MODEL_ARCHIVE_LIST", + "RwkvForCausalLM", + "RwkvModel", + "RwkvPreTrainedModel", ] ) _import_structure["models.segformer"].extend( @@ -4158,6 +4165,7 @@ ) from .models.roc_bert import ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RoCBertConfig, RoCBertTokenizer from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer +<<<<<<< HEAD from .models.sam import ( SAM_PRETRAINED_CONFIG_ARCHIVE_MAP, SamConfig, @@ -4166,6 +4174,9 @@ SamPromptEncoderConfig, SamVisionConfig, ) +======= + from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig +>>>>>>> First draft of RWKV-4 from .models.segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig from .models.sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig from .models.sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig @@ -5761,10 +5772,18 @@ RoFormerPreTrainedModel, load_tf_weights_in_roformer, ) +<<<<<<< HEAD from .models.sam import ( SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel, +======= + from .models.rwkv import ( + RWKV_PRETRAINED_MODEL_ARCHIVE_LIST, + RwkvForCausalLM, + RwkvModel, + RwkvPreTrainedModel, +>>>>>>> First draft of RWKV-4 ) from .models.segformer import ( SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/kernels/rwkv/wkv_cuda.cu b/src/transformers/kernels/rwkv/wkv_cuda.cu new file mode 100644 index 00000000000000..8337c2eacd1b3b --- /dev/null +++ b/src/transformers/kernels/rwkv/wkv_cuda.cu @@ -0,0 +1,187 @@ +#include +#include + +#define MIN_VALUE (-1e38) + +template +__global__ void kernel_forward( + const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, + const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + F aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +template +__global__ void kernel_forward_with_state( + const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, + const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset_s = _b * C * 3 + _c * 3; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + F *__restrict__ const s = _s + _offset_s; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + F aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + s[0] = aa; + s[1] = bb; + s[2] = pp; +} + +template +__global__ void kernel_backward( + const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u, + const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y, + const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, + F *__restrict__ const _gv +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const y = _y + _offset; + const F *__restrict__ const gy = _gy + _offset; + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F q[Tmax], r[Tmax]; + + F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + const F qq = gy[ii] / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = gu; + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + const F qq = q[i]; + const F rr = r[i]; + + F e1 = qq * exp(rr); + F e2 = exp(kk + pp); + gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); + gv[ii] = e1 + e2 * aa; + + const F ww = w + pp; + const F www = rr - u - kk; + const F p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward_with_state<<>>(B, T, C, w, u, k, v, y, s); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu b/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu new file mode 100644 index 00000000000000..d9ff0f427b6193 --- /dev/null +++ b/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu @@ -0,0 +1,186 @@ +#include +#include +#include "ATen/ATen.h" +#define MIN_VALUE (-1e38) +typedef at::BFloat16 bf16; + +__global__ void kernel_forward_bf16( + const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, + const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + bf16 *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + float aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +__global__ void kernel_forward_with_state_bf16( + const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, + const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y, + float *__restrict__ const _s +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset_s = _b * C * 3 + _c * 3; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + bf16 *__restrict__ const y = _y + _offset; + float *__restrict__ const s = _s + _offset_s; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + float aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = bf16(e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + s[0] = aa; + s[1] = bb; + s[2] = pp; +} + +__global__ void kernel_backward_bf16( + const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u, + const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y, + const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, + bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + const bf16 *__restrict__ const y = _y + _offset; + const bf16 *__restrict__ const gy = _gy + _offset; + bf16 *__restrict__ const gk = _gk + _offset; + bf16 *__restrict__ const gv = _gv + _offset; + + float q[Tmax], r[Tmax]; + + float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + const float qq = float(gy[ii]) / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = bf16(gu); + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + const float qq = q[i]; + const float rr = r[i]; + + float e1 = qq * exp(rr); + float e2 = exp(kk + pp); + gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb)); + gv[ii] = bf16(e1 + e2 * aa); + + const float ww = w + pp; + const float www = rr - u - kk; + const float p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward_bf16<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward_with_state_bf16<<>>(B, T, C, w, u, k, v, y, s); +} + +void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward_bf16<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/src/transformers/kernels/rwkv/wkv_op.cpp b/src/transformers/kernels/rwkv/wkv_op.cpp new file mode 100644 index 00000000000000..55e7280665927b --- /dev/null +++ b/src/transformers/kernels/rwkv/wkv_op.cpp @@ -0,0 +1,66 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); +void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s); +void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); +void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); + +void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward_with_state(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr()); +} +void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_forward_with_state_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr()); +} +void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} +void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + const int B = k.size(0); + const int T = k.size(1); + const int C = k.size(2); + cuda_backward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), + gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("forward_bf16", &forward_bf16, "wkv forward bf16"); + m.def("forward_with_state", &forward_with_state, "wkv forward with state"); + m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16"); + m.def("backward", &backward, "wkv backward"); + m.def("backward_bf16", &backward_bf16, "wkv backward bf16"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("forward_bf16", forward_bf16); + m.def("forward_with_state", forward_with_state); + m.def("forward_with_state_bf16", forward_with_state_bf16); + m.def("backward", backward); + m.def("backward_bf16", backward_bf16); +} diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5b06d7a0311ed3..261c831afee354 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -162,6 +162,7 @@ roc_bert, roformer, sam, + rwkv, segformer, sew, sew_d, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 95e2e5ab11f82c..b17e812dcce523 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -162,7 +162,11 @@ ("roberta-prelayernorm", "RobertaPreLayerNormConfig"), ("roc_bert", "RoCBertConfig"), ("roformer", "RoFormerConfig"), +<<<<<<< HEAD ("sam", "SamConfig"), +======= + ("rwkv", "RwkvConfig"), +>>>>>>> First draft of RWKV-4 ("segformer", "SegformerConfig"), ("sew", "SEWConfig"), ("sew-d", "SEWDConfig"), @@ -341,7 +345,11 @@ ("roberta-prelayernorm", "ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roc_bert", "ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), +<<<<<<< HEAD ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), +======= + ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"), +>>>>>>> First draft of RWKV-4 ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -543,6 +551,7 @@ ("roc_bert", "RoCBert"), ("roformer", "RoFormer"), ("sam", "SAM"), + ("rwkv", "RWKV"), ("segformer", "SegFormer"), ("sew", "SEW"), ("sew-d", "SEW-D"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7e6d3b38baf3b4..57fd9412b03429 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -158,6 +158,7 @@ ("roc_bert", "RoCBertModel"), ("roformer", "RoFormerModel"), ("sam", "SamModel"), + ("rwkv", "RwkvModel"), ("segformer", "SegformerModel"), ("sew", "SEWModel"), ("sew-d", "SEWDModel"), @@ -247,6 +248,7 @@ ("roberta", "RobertaForMaskedLM"), ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("roc_bert", "RoCBertForPreTraining"), + ("rwkv", "RwkvForCausalLM"), ("splinter", "SplinterForPreTraining"), ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -331,6 +333,7 @@ ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("roc_bert", "RoCBertForMaskedLM"), ("roformer", "RoFormerForMaskedLM"), + ("rwkv", "RwkvForCausalLM"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -393,6 +396,7 @@ ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"), ("roc_bert", "RoCBertForCausalLM"), ("roformer", "RoFormerForCausalLM"), + ("rwkv", "RwkvForCausalLM"), ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index de954e206ae194..378699a8178fcf 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -297,6 +297,7 @@ ), ("roc_bert", ("RoCBertTokenizer", None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), + ("rwkv", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/rwkv/__init__.py b/src/transformers/models/rwkv/__init__.py new file mode 100644 index 00000000000000..f1564d55c864d4 --- /dev/null +++ b/src/transformers/models/rwkv/__init__.py @@ -0,0 +1,69 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_keras_nlp_available, + is_tensorflow_text_available, + is_torch_available, +) + + +_import_structure = { + "configuration_rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig", "RwkvOnnxConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_rwkv"] = [ + "RWKV_PRETRAINED_MODEL_ARCHIVE_LIST", + "RwkvForCausalLM", + "RwkvModel", + "RwkvPreTrainedModel", + ] + +try: + if not is_keras_nlp_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_rwkv_tf"] = ["TFGPT2Tokenizer"] + +if TYPE_CHECKING: + from .configuration_rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig, RwkvOnnxConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_rwkv import ( + RWKV_PRETRAINED_MODEL_ARCHIVE_LIST, + RwkvForCausalLM, + RwkvModel, + RwkvPreTrainedModel, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py new file mode 100644 index 00000000000000..0a938f976708e3 --- /dev/null +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" RWKV configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "sgugger/rwkv-4-pile-7b": "https://huggingface.co/sgugger/rwkv-4-pile-7b/resolve/main/config.json", +} + + +class RwkvConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the RWVK-4 + [sgugger/rwkv-4-pile-7b](https://huggingface.co/sgugger/rwkv-4-pile-7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50277): + Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RwkvModel`]. + context_length (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model can be be used with in a single forward (using it in RNN mode + lets use any sequence length). + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the embeddings and hidden states. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + attention_hidden_size (`int`, *optional*): + Dimensionality of the attention hidden states. Will default to `hidden_size` if unset. + intermediate_size (`int`, *optional*): + Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + rescale_every (`int`, *optional*, default to 6): + At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every + `rescale_every` layer. If set to 0 or a negative number, no rescale is done. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last state. + + Example: + + ```python + >>> from transformers import RwkvConfig, RwkvModel + + >>> # Initializing a Rwkv configuration + >>> configuration = RwkvConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RwkvModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rwkv" + attribute_map = { + "max_position_embeddings": "context_length", + } + + def __init__( + self, + vocab_size=50277, + context_length=1024, + hidden_size=4096, + num_hidden_layers=32, + attention_hidden_size=None, + intermediate_size=None, + layer_norm_epsilon=1e-5, + # initializer_range=0.02, + # bos_token_id=50256, + # eos_token_id=50256, + rescale_every=6, + tie_word_embeddings=False, + use_cache=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.context_length = context_length + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attention_hidden_size = attention_hidden_size + self.intermediate_size = intermediate_size + self.layer_norm_epsilon = layer_norm_epsilon + self.rescale_every = rescale_every + self.use_cache = use_cache + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py new file mode 100644 index 00000000000000..37561cd3716752 --- /dev/null +++ b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert a RKWV checkpoint form BlinkDL to the Hugging Face format.""" + + +import argparse +import gc +import json +import os +import re + +import torch +from huggingface_hub import hf_hub_download + +from transformers import PreTrainedTokenizerFast, RwkvConfig +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint + + +NUM_HIDDEN_LAYERS_MAPPING = { + "169M": 12, + "430M": 24, + "1B5": 24, + "3B": 32, + "7B": 32, + "14B": 40, +} + +HIDEN_SIZE_MAPPING = { + "169M": 768, + "430M": 1024, + "1B5": 2048, + "3B": 2560, + "7B": 4096, + "14B": 5120, +} + + +def convert_state_dict(state_dict): + state_dict_keys = list(state_dict.keys()) + for name in state_dict_keys: + weight = state_dict.pop(name) + # emb -> embedding + if name.startswith("emb."): + name = name.replace("emb.", "embeddings.") + # ln_0 -> pre_ln (only present at block 0) + if name.startswith("blocks.0.ln0"): + name = name.replace("blocks.0.ln0", "blocks.0.pre_ln") + # att -> attention + name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name) + # ffn -> feed_forward + name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name) + # time_mix_k -> time_mix_key and reshape + if name.endswith(".time_mix_k"): + name = name.replace(".time_mix_k", ".time_mix_key") + # time_mix_v -> time_mix_value and reshape + if name.endswith(".time_mix_v"): + name = name.replace(".time_mix_v", ".time_mix_value") + # time_mix_r -> time_mix_key and reshape + if name.endswith(".time_mix_r"): + name = name.replace(".time_mix_r", ".time_mix_receptance") + + if name != "head.weight": + name = "rwkv." + name + + state_dict[name] = weight + return state_dict + + +def convert_rmkv_checkpoint_to_hf_format(repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None): + # 1. If possible, build the tokenizer. + if tokenizer_file is None: + print("No `--tokenizer_file` provided, we will only convert the model.") + vocab_size = 50277 + else: + tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) + vocab_size = len(tokenizer) + tokenizer.save_pretrained(output_dir) + + # 2. Build the config + possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys()) + if size is None: + # Try to infer size from the checkpoint name + for candidate in possible_sizes: + if candidate in checkpoint_file: + size = candidate + break + if size is None: + raise ValueError("Could not infer the size, please provide it with the `--size` argument.") + if size not in possible_sizes: + raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.") + + config = RwkvConfig( + vocab_size=vocab_size, + num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size], + hidden_size=HIDEN_SIZE_MAPPING[size], + ) + config.save_pretrained(output_dir) + + # 3. Download model file then convert state_dict + model_file = hf_hub_download(repo_id, checkpoint_file) + state_dict = torch.load(model_file, map_location="cpu") + state_dict = convert_state_dict(state_dict) + + # 4. Split in shards and save + shards, index = shard_checkpoint(state_dict) + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(output_dir, shard_file)) + + if index is not None: + save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + # 5. Clean up shards (for some reason the file PyTorch saves take the same space as the whole state_dict + print( + "Cleaning up shards. This may error with an OOM error, it this is the case don't worry you still have converted the model." + ) + shard_files = list(shards.keys()) + + del state_dict + del shards + gc.collect() + + for shard_file in shard_files: + state_dict = torch.load(shard_file) + torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, shard_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint." + ) + parser.add_argument( + "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo." + ) + parser.add_argument( + "--output_dir", default=None, type=str, required=True, help="Where to save the converted model." + ) + parser.add_argument( + "--tokenizer_file", + default=None, + type=str, + help="Path to the tokenizer file to use (if not provided, only the model is converted).", + ) + parser.add_argument( + "--size", + default=None, + type=str, + help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.", + ) + + args = parser.parse_args() + convert_rmkv_checkpoint_to_hf_format( + args.repo_id, args.checkpoint_file, args.output_dir, size=args.size, tokenizer_file=args.tokenizer_file + ) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py new file mode 100644 index 00000000000000..1d0f9b35a8f69d --- /dev/null +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -0,0 +1,747 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch RWKV model.""" + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_ninja_available, + is_torch_cuda_available, + logging, +) +from .configuration_rwkv import RwkvConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "sgugger/rwkv-4-pile-7b" +_CONFIG_FOR_DOC = "RwkvConfig" + +RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "sgugger/rwkv-4-pile-7b", + # See all RWKV models at https://huggingface.co/models?filter=rwkv +] + + +rwkv_cuda_kernel = None + + +def load_wkv_cuda_kernel(context_length): + from torch.utils.cpp_extension import load as load_kernel + + global rwkv_cuda_kernel + + kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv" + cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]] + + # Only load the kernel if it's not been loaded yet or if we changed the context length + if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length: + return + + logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.") + + flags = [ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={context_length}", + ] + rwkv_cuda_kernel = load_kernel( + name=f"wkv_{context_length}", + sources=cuda_kernel_files, + verbose=(logging.get_verbosity() == logging.DEBUG), + extra_cuda_cflags=flags, + ) + rwkv_cuda_kernel.max_seq_length = context_length + + +class RwkvLinearAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False): + batch_size, seq_len, hidden_size = key.size() + if seq_len > rwkv_cuda_kernel.max_seq_length: + raise ValueError( + f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of " + f"{rwkv_cuda_kernel.max_seq_length} with this model." + ) + if batch_size * hidden_size % min(hidden_size, 32) != 0: + raise ValueError( + f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round " + f"multiple of {min(hidden_size, 32)}." + ) + + ctx.input_dtype = key.dtype + + if ( + time_decay.device.type != "cuda" + or time_first.device.type != "cuda" + or key.device.type != "cuda" + or value.device.type != "cuda" + ): + raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.") + + time_decay = -torch.exp(time_decay.float().contiguous()) + if key.dtype == torch.float16: + time_first = time_first.float() + key = key.float() + value = value.float() + time_first = time_first.contiguous() + key = key.contiguous() + value = value.contiguous() + # The CUDA kernel will fill this tensor. + output = torch.empty_like(key, memory_format=torch.contiguous_format) + if return_state or state is not None: + if state is None: + state = torch.empty( + batch_size, + hidden_size, + 3, + dtype=torch.float32, + device=key.device, + memory_format=torch.contiguous_format, + ) + else: + state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous() + if key.dtype == torch.bfloat16: + forward_func = rwkv_cuda_kernel.forward_with_state_bf16 + else: + forward_func = rwkv_cuda_kernel.forward_with_state + # TODO: update CUDA kernel so it uses the initial state provided here. + forward_func(time_decay, time_first, key, value, output, state) + else: + forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward + forward_func(time_decay, time_first, key, value, output) + + ctx.save_for_backward(time_decay, time_first, key, value, output) + + if state is not None: + state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)] + + return output.to(ctx.input_dtype), state + + @staticmethod + # g stands for grad + def backward(ctx, g_output): + input_dtype = ctx.input_dtype + + time_decay, time_first, key, value, output = ctx.saved_tensors + # The CUDA kernel will fill those tensors. + g_time_decay = torch.empty_like( + time_decay, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32, + ) + g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format) + g_key = torch.empty_like(key, memory_format=torch.contiguous_format) + g_value = torch.empty_like(value, memory_format=torch.contiguous_format) + + if input_dtype == torch.float16: + g_output = g_output.float() + backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward + backward_func( + time_decay, + time_first, + key, + value, + output, + g_output.contiguous(), + g_time_decay, + g_time_first, + g_key, + g_value, + ) + g_time_decay = torch.sum(g_time_decay, dim=0) + g_time_first = torch.sum(g_time_first, dim=0) + + return ( + None, + None, + None, + g_time_decay.to(input_dtype), + g_time_first.to(input_dtype), + g_key.to(input_dtype), + g_value.to(input_dtype), + ) + + +def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False): + # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed + # within a torch.no_grad. + _, seq_length, _ = key.size() + output = torch.zeros_like(key) + + if state is None: + num_state = torch.zeros_like(key[:, 0], dtype=torch.float32) + den_state = torch.zeros_like(key[:, 0], dtype=torch.float32) + max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38 + else: + num_state, den_state, max_state = state + # For numerical stability + # real_numerator_state = num_state * torch.exp(max_state) + # real_denominator_state = den_state * torch.exp(max_state) + + time_decay = -torch.exp(time_decay) + + for t in range(seq_length): + current_key = key[:, t].float() + current_value = value[:, t] + + # wkv computation at time t + max_for_output = torch.maximum(max_state, current_key + time_first) + e1 = torch.exp(max_state - max_for_output) + e2 = torch.exp(current_key + time_first - max_for_output) + numerator = e1 * num_state + e2 * current_value + denominator = e1 * den_state + e2 + output[:, t] = (numerator / denominator).to(output.dtype) + + # Update state for next iteration + max_for_state = torch.maximum(max_state + time_decay, current_key) + e1 = torch.exp(max_state + time_decay - max_for_state) + e2 = torch.exp(current_key - max_for_state) + num_state = e1 * num_state + e2 * current_value + den_state = e1 * den_state + e2 + max_state = max_for_state + + if return_state or state is not None: + state = [num_state, den_state, max_state] + + return output, state + + +def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False): + no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value]) + # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version + # in this case). + one_token = key.size(1) == 1 + if rwkv_cuda_kernel is None or no_cuda or one_token: + return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state) + else: + return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state) + + +class RwkvSelfAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.config = config + kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length + if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded: + try: + load_wkv_cuda_kernel(config.context_length) + except Exception: + logger.info("Could not load the custom CUDA kernel for RWKV attention.") + self.layer_id = layer_id + hidden_size = config.hidden_size + attention_hidden_size = ( + config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size + ) + self.attention_hidden_size = attention_hidden_size + + self.time_decay = nn.Parameter(torch.empty(attention_hidden_size)) + self.time_first = nn.Parameter(torch.empty(attention_hidden_size)) + + self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False) + + # TODO: maybe jit, otherwise move inside forward + def extract_key_value(self, hidden, state=None): + # Mix hidden with the previous timestep to produce key, value, receptance + shifted = self.time_shift(hidden) if state is None or hidden.size(1) != 1 else state[1][:, :, self.layer_id] + key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) + value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value) + receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) + + key = self.key(key) + value = self.value(value) + receptance = torch.sigmoid(self.receptance(receptance)) + if state is not None: + state[1][:, :, self.layer_id] = hidden[:, -1] + return receptance, key, value, state + + def forward(self, hidden, state=None, use_cache=False): + receptance, key, value, state = self.extract_key_value(hidden, state=state) + layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None + rwkv, layer_state = rwkv_linear_attention( + self.time_decay, + self.time_first, + key, + value, + state=layer_state, + return_state=use_cache, + ) + + if layer_state is not None: + state[2][:, :, self.layer_id] = layer_state[0] + state[3][:, :, self.layer_id] = layer_state[1] + state[4][:, :, self.layer_id] = layer_state[2] + + return self.output(receptance * rwkv), state + + +class RwkvFeedForward(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.config = config + self.layer_id = layer_id + hidden_size = config.hidden_size + intermediate_size = ( + config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size + ) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.time_mix_key = nn.Parameter(torch.empty(hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(hidden_size)) + + self.key = nn.Linear(hidden_size, intermediate_size, bias=False) + self.receptance = nn.Linear(hidden_size, hidden_size, bias=False) + self.value = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, hidden, state=None): + shifted = self.time_shift(hidden) if state is None or hidden.size(1) != 1 else state[0][:, :, self.layer_id] + key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) + receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) + + key = torch.square(torch.relu(self.key(key))) + value = self.value(key) + receptance = torch.sigmoid(self.receptance(receptance)) + + if state is not None: + state[0][:, :, self.layer_id] = hidden[:, -1] + + return receptance * value, state + + +class RwkvBlock(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.config = config + self.layer_id = layer_id + + if layer_id == 0: + self.pre_ln = nn.LayerNorm(config.hidden_size) + + self.ln1 = nn.LayerNorm(config.hidden_size) + self.ln2 = nn.LayerNorm(config.hidden_size) + + self.attention = RwkvSelfAttention(config, layer_id) + self.feed_forward = RwkvFeedForward(config, layer_id) + + def forward(self, hidden, state=None, use_cache=False): + if self.layer_id == 0: + hidden = self.pre_ln(hidden) + + attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache) + hidden = hidden + attention + + feed_forward, state = self.feed_forward(self.ln2(hidden), state=state) + hidden = hidden + feed_forward + return hidden, state + + +class RwkvPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RwkvConfig + base_model_prefix = "rwkv" + # supports_gradient_checkpointing = True + _no_split_modules = ["RwkvBlock"] + _keep_in_fp32_modules = ["time_decay", "time_first"] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, RwkvSelfAttention): + layer_id = module.layer_id + num_hidden_layers = module.config.num_hidden_layers + hidden_size = module.config.hidden_size + attention_hidden_size = module.attention_hidden_size + + ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 + + ddd = torch.tensor( + [i / hidden_size for i in range(hidden_size)], + dtype=module.time_mix_key.dtype, + device=module.time_mix_key.device, + ) + ddd = ddd[None, None, :] + + decay_speed = [ + -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + for h in range(attention_hidden_size) + ] + decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device) + zigzag = ( + torch.tensor( + [(i + 1) % 3 - 1 for i in range(attention_hidden_size)], + dtype=module.time_first.dtype, + device=module.time_first.device, + ) + * 0.5 + ) + + with torch.no_grad(): + module.time_decay.data = decay_speed + module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag) + + module.time_mix_key.data = torch.pow(ddd, ratio_1_to_almost0) + module.time_mix_value.data = torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + module.time_mix_receptance.data = torch.pow(ddd, 0.5 * ratio_1_to_almost0) + elif isinstance(module, RwkvFeedForward): + layer_id = module.layer_id + num_hidden_layers = module.config.num_hidden_layers + hidden_size = module.config.hidden_size + + ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 + + ddd = torch.tensor( + [i / hidden_size for i in range(hidden_size)], + dtype=module.time_mix_key.dtype, + device=module.time_mix_key.device, + ) + ddd = ddd[None, None, :] + + with torch.no_grad(): + module.time_mix_key.data = torch.pow(ddd, ratio_1_to_almost0) + module.time_mix_receptance.data = torch.pow(ddd, ratio_1_to_almost0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RwkvModel): + module.gradient_checkpointing = value + + +@dataclass +class RwkvOutput(ModelOutput): + """ + Class for the RWKV model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + state: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RwkvCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + state: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +RWKV_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RwkvConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RWKV_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + use_cache (`bool`, *optional*): + If set to `True`, the last state is returned and can be used to quickly generate the next logits. +""" + + +@add_start_docstrings( + "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.", + RWKV_START_DOCSTRING, +) +class RwkvModel(RwkvPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) + self.ln_out = nn.LayerNorm(config.hidden_size) + + self.layers_are_rescaled = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=RwkvOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + state: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RwkvOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training == self.layers_are_rescaled: + self._rescale_layers() + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if use_cache and state is None: + shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) + state = [ + torch.zeros(*shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=input_ids.device) + for i in range(5) + ] + state[4] -= 1e30 + + hidden_states = inputs_embeds + + # TODO: add this. + # all_self_attentions = () if output_attentions else None + # all_hidden_states = () if output_hidden_states else None + for idx, block in enumerate(self.blocks): + hidden_states, state = block(hidden_states, state=state, use_cache=use_cache) + if ( + self.layers_are_rescaled + and self.config.rescale_every > 0 + and (idx + 1) % self.config.rescale_every == 0 + ): + hidden_states = hidden_states / 2 + + hidden_states = self.ln_out(hidden_states) + + if not return_dict: + (hidden_states,) + + return RwkvOutput(last_hidden_state=hidden_states, state=state, hidden_states=None, attentions=None) + + def _rescale_layers(self): + # Layers should be rescaled for inference only. + if self.layers_are_rescaled == (not self.training): + return + if self.config.rescale_every > 0: + with torch.no_grad(): + for block_id, block in enumerate(self.blocks): + if self.training: + block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every)) + else: + block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every)) + + self.layers_are_rescaled = not self.training + + +@add_start_docstrings( + """ + The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + RWKV_START_DOCSTRING, +) +class RwkvForCausalLM(RwkvPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.rwkv = RwkvModel(config) + self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.head + + def set_output_embeddings(self, new_embeddings): + self.head = new_embeddings + + @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=RwkvCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + state: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RwkvCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + rwkv_outputs = self.rwkv( + input_ids, + inputs_embeds=inputs_embeds, + state=state, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = rwkv_outputs[0] + + logits = self.head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + rwkv_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return RwkvCausalLMOutput( + loss=loss, + logits=logits, + state=rwkv_outputs.state, + hidden_states=rwkv_outputs.hidden_states, + attentions=rwkv_outputs.attentions, + ) diff --git a/tests/models/rwkv/__init__.py b/tests/models/rwkv/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py new file mode 100644 index 00000000000000..a9f843dc106c09 --- /dev/null +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -0,0 +1,281 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import RwkvConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + from transformers import ( + RWKV_PRETRAINED_MODEL_ARCHIVE_LIST, + RwkvForCausalLM, + RwkvModel, + ) + + +class RwkvModelTester: + def __init__( + self, + parent, + batch_size=14, + seq_length=7, + is_training=True, + use_token_type_ids=True, + use_input_mask=True, + use_labels=True, + use_mc_token_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_token_type_ids = use_token_type_ids + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = None + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + + def get_large_model_config(self): + return RwkvConfig.from_pretrained("sgugger/rwkv-4-pile-7b") + + def prepare_config_and_inputs( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config( + gradient_checkpointing=gradient_checkpointing, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + + head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def get_config( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + return RwkvConfig( + vocab_size=self.vocab_size, + n_embd=self.hidden_size, + n_layer=self.num_hidden_layers, + n_head=self.num_attention_heads, + n_inner=self.intermediate_size, + activation_function=self.hidden_act, + resid_pdrop=self.hidden_dropout_prob, + attn_pdrop=self.attention_probs_dropout_prob, + n_positions=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + use_cache=True, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + gradient_checkpointing=gradient_checkpointing, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + + def get_pipeline_config(self): + config = self.get_config() + config.vocab_size = 300 + return config + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = RwkvModel(config=config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(len(result.past_key_values), config.n_layer) + + def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = RwkvForCausalLM(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): + model = RwkvForCausalLM(config) + model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + result.loss.backward() + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "head_mask": head_mask, + } + + return config, inputs_dict + + +@require_torch +class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else () + # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else () + fx_compatible = False + test_missing_keys = False + test_model_parallel = True + + def setUp(self): + self.model_tester = RwkvModelTester(self) + self.config_tester = ConfigTester(self, config_class=RwkvConfig, n_embd=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_rwkv_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_rwkv_model(*config_and_inputs) + + def test_rwkv_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causl_lm(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = RwkvModel.from_pretrained(model_name) + self.assertIsNotNone(model) From db50d8b889d045b415900ddd42d39103980112bc Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Mon, 17 Apr 2023 10:22:04 -0400 Subject: [PATCH 02/27] Add support for generate --- src/transformers/generation/utils.py | 2 ++ src/transformers/models/rwkv/modeling_rwkv.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 4a0e621de14c14..a68b5acf66d413 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -753,6 +753,8 @@ def _update_model_kwargs_for_generation( model_kwargs["past_key_values"] = self._extract_past_from_model_output( outputs, standardize_cache_format=standardize_cache_format ) + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state # update token_type_ids with last value if "token_type_ids" in model_kwargs: diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 1d0f9b35a8f69d..ecee1402435b21 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -685,6 +685,20 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.head = new_embeddings + def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): + # only last token for inputs_ids if the state is passed along. + if state is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and state is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs["state"] = state + return model_inputs + @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, From b6273e9e3c827c28c297aed182de0705feeec1cd Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 27 Apr 2023 15:39:16 -0400 Subject: [PATCH 03/27] Style post-rebase --- src/transformers/__init__.py | 31 +++++++++---------- src/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 12 ++----- src/transformers/models/auto/modeling_auto.py | 2 +- 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9fd080a5fb5672..c45ca1bff10a53 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -430,6 +430,7 @@ "models.roberta_prelayernorm": ["ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaPreLayerNormConfig"], "models.roc_bert": ["ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoCBertConfig", "RoCBertTokenizer"], "models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"], + "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"], "models.sam": [ "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP", "SamConfig", @@ -438,7 +439,6 @@ "SamPromptEncoderConfig", "SamVisionConfig", ], - "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"], "models.segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"], "models.sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"], "models.sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"], @@ -2356,11 +2356,6 @@ "load_tf_weights_in_roformer", ] ) - _import_structure["models.sam"].extend( - [ - "SAM_PRETRAINED_MODEL_ARCHIVE_LIST", - "SamModel", - "SamPreTrainedModel", _import_structure["models.rwkv"].extend( [ "RWKV_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2369,6 +2364,13 @@ "RwkvPreTrainedModel", ] ) + _import_structure["models.sam"].extend( + [ + "SAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "SamModel", + "SamPreTrainedModel", + ] + ) _import_structure["models.segformer"].extend( [ "SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4165,7 +4167,7 @@ ) from .models.roc_bert import ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RoCBertConfig, RoCBertTokenizer from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer -<<<<<<< HEAD + from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig from .models.sam import ( SAM_PRETRAINED_CONFIG_ARCHIVE_MAP, SamConfig, @@ -4174,9 +4176,6 @@ SamPromptEncoderConfig, SamVisionConfig, ) -======= - from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig ->>>>>>> First draft of RWKV-4 from .models.segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig from .models.sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig from .models.sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig @@ -5772,18 +5771,16 @@ RoFormerPreTrainedModel, load_tf_weights_in_roformer, ) -<<<<<<< HEAD - from .models.sam import ( - SAM_PRETRAINED_MODEL_ARCHIVE_LIST, - SamModel, - SamPreTrainedModel, -======= from .models.rwkv import ( RWKV_PRETRAINED_MODEL_ARCHIVE_LIST, RwkvForCausalLM, RwkvModel, RwkvPreTrainedModel, ->>>>>>> First draft of RWKV-4 + ) + from .models.sam import ( + SAM_PRETRAINED_MODEL_ARCHIVE_LIST, + SamModel, + SamPreTrainedModel, ) from .models.segformer import ( SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 261c831afee354..5d7106b3874242 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -161,8 +161,8 @@ roberta_prelayernorm, roc_bert, roformer, - sam, rwkv, + sam, segformer, sew, sew_d, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index b17e812dcce523..640e1cd44aef5b 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -162,11 +162,8 @@ ("roberta-prelayernorm", "RobertaPreLayerNormConfig"), ("roc_bert", "RoCBertConfig"), ("roformer", "RoFormerConfig"), -<<<<<<< HEAD - ("sam", "SamConfig"), -======= ("rwkv", "RwkvConfig"), ->>>>>>> First draft of RWKV-4 + ("sam", "SamConfig"), ("segformer", "SegformerConfig"), ("sew", "SEWConfig"), ("sew-d", "SEWDConfig"), @@ -345,11 +342,8 @@ ("roberta-prelayernorm", "ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roc_bert", "ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), -<<<<<<< HEAD - ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), -======= ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"), ->>>>>>> First draft of RWKV-4 + ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -550,8 +544,8 @@ ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"), ("roc_bert", "RoCBert"), ("roformer", "RoFormer"), - ("sam", "SAM"), ("rwkv", "RWKV"), + ("sam", "SAM"), ("segformer", "SegFormer"), ("sew", "SEW"), ("sew-d", "SEW-D"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 57fd9412b03429..9d094b9d044d8e 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -157,8 +157,8 @@ ("roberta-prelayernorm", "RobertaPreLayerNormModel"), ("roc_bert", "RoCBertModel"), ("roformer", "RoFormerModel"), - ("sam", "SamModel"), ("rwkv", "RwkvModel"), + ("sam", "SamModel"), ("segformer", "SegformerModel"), ("sew", "SEWModel"), ("sew-d", "SEWDModel"), From 7b819307240c21cccff66d8f1758627eff9ee02d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 27 Apr 2023 17:03:21 -0400 Subject: [PATCH 04/27] Properly use state --- src/transformers/kernels/rwkv/wkv_cuda.cu | 2 +- .../kernels/rwkv/wkv_cuda_bf16.cu | 2 +- src/transformers/models/rwkv/modeling_rwkv.py | 21 ++++++++++++++----- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/transformers/kernels/rwkv/wkv_cuda.cu b/src/transformers/kernels/rwkv/wkv_cuda.cu index 8337c2eacd1b3b..571d5a8a8307e9 100644 --- a/src/transformers/kernels/rwkv/wkv_cuda.cu +++ b/src/transformers/kernels/rwkv/wkv_cuda.cu @@ -61,7 +61,7 @@ __global__ void kernel_forward_with_state( F *__restrict__ const s = _s + _offset_s; // aa and bb are running sums divided by exp(pp) (to avoid overflow) - F aa = 0, bb = 0, pp = MIN_VALUE; + F aa = s[0], bb = s[1], pp = s[2]; for (int i = 0; i < T; i++) { const int ii = i * C; const F kk = k[ii]; diff --git a/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu b/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu index d9ff0f427b6193..042cb4aba1db98 100644 --- a/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu +++ b/src/transformers/kernels/rwkv/wkv_cuda_bf16.cu @@ -61,7 +61,7 @@ __global__ void kernel_forward_with_state_bf16( float *__restrict__ const s = _s + _offset_s; // aa and bb are running sums divided by exp(pp) (to avoid overflow) - float aa = 0, bb = 0, pp = MIN_VALUE; + float aa = s[0], bb = s[1], pp = s[2]; for (int i = 0; i < T; i++) { const int ii = i * C; const float kk = float(k[ii]); diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index ecee1402435b21..73d82b78753a62 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -121,7 +121,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa output = torch.empty_like(key, memory_format=torch.contiguous_format) if return_state or state is not None: if state is None: - state = torch.empty( + state = torch.zeros( batch_size, hidden_size, 3, @@ -129,6 +129,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa device=key.device, memory_format=torch.contiguous_format, ) + state[:, :, 2] -= 1e38 else: state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous() if key.dtype == torch.bfloat16: @@ -281,7 +282,12 @@ def __init__(self, config, layer_id=0): # TODO: maybe jit, otherwise move inside forward def extract_key_value(self, hidden, state=None): # Mix hidden with the previous timestep to produce key, value, receptance - shifted = self.time_shift(hidden) if state is None or hidden.size(1) != 1 else state[1][:, :, self.layer_id] + if hidden.size(1) == 1 and state is not None: + shifted = state[1][:, :, self.layer_id] + else: + shifted = self.time_shift(hidden) + if state is not None: + shifted[:, 0] = state[1][:, :, self.layer_id] key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value) receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) @@ -324,15 +330,20 @@ def __init__(self, config, layer_id=0): ) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - self.time_mix_key = nn.Parameter(torch.empty(hidden_size)) - self.time_mix_receptance = nn.Parameter(torch.empty(hidden_size)) + self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) self.key = nn.Linear(hidden_size, intermediate_size, bias=False) self.receptance = nn.Linear(hidden_size, hidden_size, bias=False) self.value = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, hidden, state=None): - shifted = self.time_shift(hidden) if state is None or hidden.size(1) != 1 else state[0][:, :, self.layer_id] + if hidden.size(1) == 1 and state is not None: + shifted = state[0][:, :, self.layer_id] + else: + shifted = self.time_shift(hidden) + if state is not None: + shifted[:, 0] = state[0][:, :, self.layer_id] key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) From 400147ec52b4bc99f58b7d047a4b737e6834a038 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 27 Apr 2023 20:13:00 -0400 Subject: [PATCH 05/27] Write doc --- docs/source/en/model_doc/rwkv.mdx | 94 ++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/rwkv.mdx b/docs/source/en/model_doc/rwkv.mdx index 71c07b71810afa..4c75bd9e508087 100644 --- a/docs/source/en/model_doc/rwkv.mdx +++ b/docs/source/en/model_doc/rwkv.mdx @@ -16,14 +16,36 @@ specific language governing permissions and limitations under the License. The RWKV model was proposed in [this repo](https://github.com/BlinkDL/RWKV-LM) -TODO, need to write this page - -Tips: +It suggests a tweak in the traditional Transformer attention to make it linear. This way, the model can be used as recurrent network: passing inputs for timestamp 0 and timestamp 1 together is the same as passing inputs at timestamp 0, then inputs at timestamp 1 along with the state of timestamp 0 (see example below). +This can be more efficient than a regular Transformer and can deal with sentence of any length (even if the model uses a fixed context length for training). This model was contributed by [sgugger](https://huggingface.co/sgugger). The original code can be found [here](https://github.com/BlinkDL/RWKV-LM). +Example of use as an RNN: + +```py +import torch +from transformers import AutoTokenizer, RwkvConfig, RwkvModel + +model = RwkvModel.from_pretrained("sgugger/rwkv-430M-pile") +tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-430M-pile") + +inputs = tokenizer("This is an example.", return_tensors="pt") +# Feed everything to the model +outputs = model(inputs["input_ids"]) +output_whole = outputs.last_hidden_state + +outputs = model(inputs["input_ids"][:, :2]) +output_one = outputs.last_hidden_state + +# Using the state computed on the first inputs, we will get the same output +outputs = model(inputs["input_ids"][:, 2:], state=outputs.state) +output_two = outputs.last_hidden_state + +torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5) +``` ## RwkvConfig @@ -41,3 +63,69 @@ The original code can be found [here](https://github.com/BlinkDL/RWKV-LM). [[autodoc]] RwkvForCausalLM - forward - forward_rnn + +## Rwkv attention and the recurrent formulas + +In a traditional auto-regressive Transformer, attention is written as + +$$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$ + +with `Q`, `K` and `V` are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the maxtrix product with `V` to get the output `O` of the same shape as the others. + +Replacing the softmax by its value gives: + +$$O_{i} = \frac{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}} V_{j}}{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}}}$$ + +Note that the entries in \\(QK^{T}\\) corresponding to \\(j > i\\) are masked (the sum stops at j) because the attention is not allowed to look at future tokens (only past ones). + +In comparison, the RWKV attention is given by + +$$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$ + +where `R` is a new matrix called receptance by the author, `K` and `V` are still the key and value (\\(\sigma\\) here is the sigmoid function). `W` is a new vector that represents the position of the token and is given by + +$$W_{0} = u \hbox{ and } W_{k} = (k-1)w \hbox{ for } k \geq 1$$ + +with `u` and `w` learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have: + +$$N_{i} = e^{u + K_{i}} V_{i} + Ns_{i} \hbox{ where } Ns_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$ + +so \\(Ns_{i}\\) satistfies + +$$Ns_{0} = 0 \hbox{ and } Ns_{j+1} = e^{K_{j}} V_{j} + e^{w} Ns_{j}$$ + +and + +$$D_{i} = e^{u + K_{i}} + Ds_{i} \hbox{ where } Ds_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$ + +so \\(Ds_{i}\\) satistfies + +$$Ds_{0} = 0 \hbox{ and } Ds_{j+1} = e^{K_{j}} + e^{w} Ds_{j}$$ + +The actual recurrent formula used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is, but the exponential of the maximum term is divided of the numerator and denominator: + +$$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$ + +with `M` the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(Ns\\)) and the denominator state (\\(Ds\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use + +$$\tilde{Ns}_{i} = e^{-M_{i}} Ns_{i} \hbox{ and } \tilde{Ds}_{i} = e^{-M_{i}} Ds_{i}$$ + +defined by the following recurrent formulas: + +$$\tilde{Ns}_{0} = 0 \hbox{ and } \tilde{Ns}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{Ns}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ + +and + +$$\tilde{Ds}_{0} = 0 \hbox{ and } \tilde{Ds}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{Ds}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ + +and \\(M_{j+1} = q\\). With those, we can then compute + +$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i}} \tilde{Ns}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ + +and + +$$D_{i} = e^{u + K_{i} - q} + e^{M_{i}} \tilde{Ds}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ + +which finally gives us + +$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$ \ No newline at end of file From 8603532437a9b5c93df5bfc11ec67fdd7ed29cc7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 27 Apr 2023 20:18:55 -0400 Subject: [PATCH 06/27] Fix doc --- docs/source/en/model_doc/rwkv.mdx | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/en/model_doc/rwkv.mdx b/docs/source/en/model_doc/rwkv.mdx index 4c75bd9e508087..68e2fd07099cd4 100644 --- a/docs/source/en/model_doc/rwkv.mdx +++ b/docs/source/en/model_doc/rwkv.mdx @@ -56,13 +56,11 @@ torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e [[autodoc]] RwkvModel - forward - - forward_rnn ## RwkvLMHeadModel [[autodoc]] RwkvForCausalLM - forward - - forward_rnn ## Rwkv attention and the recurrent formulas From bc0cd7cbebe23e49697c0908493ece7d205b65e1 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 27 Apr 2023 21:06:05 -0400 Subject: [PATCH 07/27] More math --- docs/source/en/model_doc/rwkv.mdx | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/en/model_doc/rwkv.mdx b/docs/source/en/model_doc/rwkv.mdx index 68e2fd07099cd4..b8812516560f3c 100644 --- a/docs/source/en/model_doc/rwkv.mdx +++ b/docs/source/en/model_doc/rwkv.mdx @@ -68,7 +68,7 @@ In a traditional auto-regressive Transformer, attention is written as $$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$ -with `Q`, `K` and `V` are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the maxtrix product with `V` to get the output `O` of the same shape as the others. +with \\(Q\\), \\(K\\) and \\(V\\) are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the maxtrix product with \\(V\\) to get the output \\(O\\) of the same shape as the others. Replacing the softmax by its value gives: @@ -80,49 +80,49 @@ In comparison, the RWKV attention is given by $$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$ -where `R` is a new matrix called receptance by the author, `K` and `V` are still the key and value (\\(\sigma\\) here is the sigmoid function). `W` is a new vector that represents the position of the token and is given by +where \\(R\\) is a new matrix called receptance by the author, \\(K\\) and \\(V\\) are still the key and value (\\(\sigma\\) here is the sigmoid function). \\(W\\) is a new vector that represents the position of the token and is given by $$W_{0} = u \hbox{ and } W_{k} = (k-1)w \hbox{ for } k \geq 1$$ -with `u` and `w` learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have: +with \\(u\\) and \\(w\\) learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have: -$$N_{i} = e^{u + K_{i}} V_{i} + Ns_{i} \hbox{ where } Ns_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$ +$$N_{i} = e^{u + K_{i}} V_{i} + \hat{N}_{i} \hbox{ where } \hat{N}_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$ -so \\(Ns_{i}\\) satistfies +so \\(\hat{N}_{i}\\) (called `numerator_state` in the code) satistfies -$$Ns_{0} = 0 \hbox{ and } Ns_{j+1} = e^{K_{j}} V_{j} + e^{w} Ns_{j}$$ +$$\hat{N}_{0} = 0 \hbox{ and } \hat{N}_{j+1} = e^{K_{j}} V_{j} + e^{w} \hat{N}_{j}$$ and -$$D_{i} = e^{u + K_{i}} + Ds_{i} \hbox{ where } Ds_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$ +$$D_{i} = e^{u + K_{i}} + \hat{D}_{i} \hbox{ where } \hat{D}_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$ -so \\(Ds_{i}\\) satistfies +so \\(\hat{D}_{i}\\) (called `denominator_state` in the code) satistfies -$$Ds_{0} = 0 \hbox{ and } Ds_{j+1} = e^{K_{j}} + e^{w} Ds_{j}$$ +$$\hat{D}_{0} = 0 \hbox{ and } \hat{D}_{j+1} = e^{K_{j}} + e^{w} \hat{D}_{j}$$ The actual recurrent formula used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is, but the exponential of the maximum term is divided of the numerator and denominator: $$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$ -with `M` the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(Ns\\)) and the denominator state (\\(Ds\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use +with \\(M\\) the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(\hat{N}\\)) and the denominator state (\\(\hat{D}\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use -$$\tilde{Ns}_{i} = e^{-M_{i}} Ns_{i} \hbox{ and } \tilde{Ds}_{i} = e^{-M_{i}} Ds_{i}$$ +$$\tilde{N}_{i} = e^{-M_{i}} \hat{N}_{i} \hbox{ and } \tilde{D}_{i} = e^{-M_{i}} \hat{D}_{i}$$ defined by the following recurrent formulas: -$$\tilde{Ns}_{0} = 0 \hbox{ and } \tilde{Ns}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{Ns}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ +$$\tilde{N}_{0} = 0 \hbox{ and } \tilde{N}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{N}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ and -$$\tilde{Ds}_{0} = 0 \hbox{ and } \tilde{Ds}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{Ds}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ +$$\tilde{D}_{0} = 0 \hbox{ and } \tilde{D}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{D}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ and \\(M_{j+1} = q\\). With those, we can then compute -$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i}} \tilde{Ns}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ +$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i}} \tilde{N}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ and -$$D_{i} = e^{u + K_{i} - q} + e^{M_{i}} \tilde{Ds}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ +$$D_{i} = e^{u + K_{i} - q} + e^{M_{i}} \tilde{D}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ which finally gives us From 773658437e64d52e4fd31b5e6339102da5d9ff20 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Fri, 28 Apr 2023 10:07:51 -0400 Subject: [PATCH 08/27] Add model to README, dummies and clean config --- README.md | 1 + README_es.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/index.mdx | 2 ++ docs/source/en/tasks/language_modeling.mdx | 2 +- .../models/rwkv/configuration_rwkv.py | 1 - src/transformers/models/rwkv/modeling_rwkv.py | 6 ++--- src/transformers/utils/dummy_pt_objects.py | 24 +++++++++++++++++++ 12 files changed, 37 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index bdec263bd6c350..7ff099d9b7f34f 100644 --- a/README.md +++ b/README.md @@ -421,6 +421,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Peng Bo), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/README_es.md b/README_es.md index 3cce101e018865..34a197fc68385c 100644 --- a/README_es.md +++ b/README_es.md @@ -409,6 +409,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Peng Bo) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/README_hd.md b/README_hd.md index b758322ec5e019..ad4b4ce0d114b0 100644 --- a/README_hd.md +++ b/README_hd.md @@ -381,6 +381,7 @@ conda install -c huggingface transformers 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (झुईई टेक्नोलॉजी से), साथ में पेपर [रोफॉर्मर: रोटरी पोजिशन एंबेडिंग के साथ एन्हांस्ड ट्रांसफॉर्मर] (https://arxiv.org/pdf/2104.09864v1.pdf) जियानलिन सु और यू लू और शेंगफेंग पैन और बो वेन और युनफेंग लियू द्वारा प्रकाशित। +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Peng Bo से) Peng Bo. द्वाराअनुसंधान पत्र [this repo](https://github.com/BlinkDL/RWKV-LM) के साथ जारी किया गया 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (Meta AI से) Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. द्वाराअनुसंधान पत्र [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) के साथ जारी किया गया 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP से) साथ देने वाला पेपर [भाषण पहचान के लिए अनसुपरवाइज्ड प्री-ट्रेनिंग में परफॉर्मेंस-एफिशिएंसी ट्रेड-ऑफ्स](https ://arxiv.org/abs/2109.06870) फेलिक्स वू, क्वांगयुन किम, जिंग पैन, क्यू हान, किलियन क्यू. वेनबर्गर, योव आर्टज़ी द्वारा। diff --git a/README_ja.md b/README_ja.md index 951324ec045b96..b4ebfc9b70802a 100644 --- a/README_ja.md +++ b/README_ja.md @@ -443,6 +443,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (Facebook から) Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli から公開された研究論文: [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (WeChatAI から) HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou から公開された研究論文: [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (ZhuiyiTechnology から), Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu から公開された研究論文: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Peng Bo から) Peng Bo. から公開された研究論文 [this repo](https://github.com/BlinkDL/RWKV-LM) 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (NVIDIA から) Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo から公開された研究論文: [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (Meta AI から) Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. から公開された研究論文 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP から) Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi から公開された研究論文: [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) diff --git a/README_ko.md b/README_ko.md index 71a153fb8f909d..9ce041a9c0fa8d 100644 --- a/README_ko.md +++ b/README_ko.md @@ -358,6 +358,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (Facebook 에서) Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli 의 [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) 논문과 함께 발표했습니다. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (WeChatAI 에서) HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou 의 [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 논문과 함께 발표했습니다. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (ZhuiyiTechnology 에서) Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 의 a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 논문과 함께 발표했습니다. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Peng Bo 에서 제공)은 Peng Bo.의 [this repo](https://github.com/BlinkDL/RWKV-LM)논문과 함께 발표했습니다. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (NVIDIA 에서) Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 의 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 논문과 함께 발표했습니다. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (Meta AI 에서 제공)은 Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.의 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf)논문과 함께 발표했습니다. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP 에서) Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi 의 [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 1814b23ff5e539..7c93f894ec1f1d 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -382,6 +382,7 @@ conda install -c huggingface transformers 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (来自 Facebook) 伴随论文 [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) 由 Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli 发布。 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (来自 WeChatAI), 伴随论文 [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 由 HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou 发布。 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (来自 ZhuiyiTechnology), 伴随论文 [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 由 Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 发布。 +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (来自 Peng Bo) 伴随论文 [this repo](https://github.com/BlinkDL/RWKV-LM) 由 Peng Bo 发布。 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (来自 NVIDIA) 伴随论文 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 由 Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 发布。 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (来自 Meta AI) 伴随论文 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) 由 Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick 发布。 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (来自 ASAPP) 伴随论文 [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) 由 Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index b82f3afd13de54..81c334832e1bed 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -394,6 +394,7 @@ conda install -c huggingface transformers 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Peng Bo) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 259e1cce4fa50e..4d86aeed0cfe3f 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -195,6 +195,7 @@ The documentation is organized into five sections: 1. **[RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. +1. **[RWKV](model_doc/rwkv)** (from Peng Bo), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. 1. **[SegFormer](model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. @@ -394,6 +395,7 @@ Flax), PyTorch, and/or TensorFlow. | RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ | | RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | +| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ | | SAM | ❌ | ❌ | ✅ | ❌ | ❌ | | SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ | | SEW | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/tasks/language_modeling.mdx b/docs/source/en/tasks/language_modeling.mdx index b79435b08f317c..c3aba8ff0d4e15 100644 --- a/docs/source/en/tasks/language_modeling.mdx +++ b/docs/source/en/tasks/language_modeling.mdx @@ -34,7 +34,7 @@ Choose one of the following architectures: -[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) +[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index 0a938f976708e3..aebd364415c6bd 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -89,7 +89,6 @@ def __init__( attention_hidden_size=None, intermediate_size=None, layer_norm_epsilon=1e-5, - # initializer_range=0.02, # bos_token_id=50256, # eos_token_id=50256, rescale_every=6, diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 73d82b78753a62..199be63b9211ba 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -364,10 +364,10 @@ def __init__(self, config, layer_id): self.layer_id = layer_id if layer_id == 0: - self.pre_ln = nn.LayerNorm(config.hidden_size) + self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.ln1 = nn.LayerNorm(config.hidden_size) - self.ln2 = nn.LayerNorm(config.hidden_size) + self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attention = RwkvSelfAttention(config, layer_id) self.feed_forward = RwkvFeedForward(config, layer_id) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 9a7bf661159afa..aae3d34310d4f6 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6042,6 +6042,30 @@ def load_tf_weights_in_roformer(*args, **kwargs): requires_backends(load_tf_weights_in_roformer, ["torch"]) +RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class RwkvForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RwkvModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RwkvPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None From a6aa93279da6d0a117d1b7b41f775151fe82085d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Fri, 28 Apr 2023 10:28:49 -0400 Subject: [PATCH 09/27] Fix init --- src/transformers/models/rwkv/__init__.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/transformers/models/rwkv/__init__.py b/src/transformers/models/rwkv/__init__.py index f1564d55c864d4..e68eefe9f8aaa5 100644 --- a/src/transformers/models/rwkv/__init__.py +++ b/src/transformers/models/rwkv/__init__.py @@ -17,8 +17,6 @@ from ...utils import ( OptionalDependencyNotAvailable, _LazyModule, - is_keras_nlp_available, - is_tensorflow_text_available, is_torch_available, ) @@ -40,13 +38,6 @@ "RwkvPreTrainedModel", ] -try: - if not is_keras_nlp_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["tokenization_rwkv_tf"] = ["TFGPT2Tokenizer"] if TYPE_CHECKING: from .configuration_rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig, RwkvOnnxConfig From a50d49cc209f739dae895d7dacfd1ced5f61c787 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 May 2023 15:27:02 +0000 Subject: [PATCH 10/27] multiple fixes: - fix common tests - fix configuraion default values - add CI test for checking state computation - fix some CI tests --- .../models/rwkv/configuration_rwkv.py | 4 +- tests/models/rwkv/test_modeling_rwkv.py | 62 ++++++++++++++----- tests/test_configuration_common.py | 9 ++- 3 files changed, 54 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index aebd364415c6bd..6789301d19d3cb 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -100,8 +100,8 @@ def __init__( self.context_length = context_length self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers - self.attention_hidden_size = attention_hidden_size - self.intermediate_size = intermediate_size + self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size self.layer_norm_epsilon = layer_norm_epsilon self.rescale_every = rescale_every self.use_cache = use_cache diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index a9f843dc106c09..14faa961a453e7 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -26,6 +26,8 @@ if is_torch_available(): + import torch + from transformers import ( RWKV_PRETRAINED_MODEL_ARCHIVE_LIST, RwkvForCausalLM, @@ -40,7 +42,7 @@ def __init__( batch_size=14, seq_length=7, is_training=True, - use_token_type_ids=True, + use_token_type_ids=False, use_input_mask=True, use_labels=True, use_mc_token_ids=True, @@ -121,13 +123,11 @@ def prepare_config_and_inputs( reorder_and_upcast_attn=reorder_and_upcast_attn, ) - head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) - return ( config, input_ids, input_mask, - head_mask, + None, token_type_ids, mc_token_ids, sequence_labels, @@ -140,10 +140,9 @@ def get_config( ): return RwkvConfig( vocab_size=self.vocab_size, - n_embd=self.hidden_size, - n_layer=self.num_hidden_layers, - n_head=self.num_attention_heads, - n_inner=self.intermediate_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=self.intermediate_size, activation_function=self.hidden_act, resid_pdrop=self.hidden_dropout_prob, attn_pdrop=self.attention_probs_dropout_prob, @@ -198,22 +197,37 @@ def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, model.to(torch_device) model.eval() - result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) - result = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - self.parent.assertEqual(len(result.past_key_values), config.n_layer) + self.parent.assertEqual(len(result.hidden_states), config.n_layer) def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = RwkvForCausalLM(config) model.to(torch_device) model.eval() - result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + result = model(input_ids, labels=input_ids) self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = RwkvModel(config=config) + model.to(torch_device) + model.eval() + + outputs = model(input_ids) + output_whole = outputs.last_hidden_state + + outputs = model(input_ids[:, :2]) + output_one = outputs.last_hidden_state + + # Using the state computed on the first inputs, we will get the same output + outputs = model(input_ids[:, 2:], state=outputs.state) + output_two = outputs.last_hidden_state + + self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) + def create_and_check_forward_and_backwards( self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False ): @@ -222,7 +236,7 @@ def create_and_check_forward_and_backwards( if gradient_checkpointing: model.gradient_checkpointing_enable() - result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + result = model(input_ids, labels=input_ids) self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) result.loss.backward() @@ -244,8 +258,6 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = { "input_ids": input_ids, - "token_type_ids": token_type_ids, - "head_mask": head_mask, } return config, inputs_dict @@ -254,14 +266,26 @@ def prepare_config_and_inputs_for_common(self): @require_torch class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": RwkvModel, + "text-generation": RwkvForCausalLM, + } + if is_torch_available() + else {} + ) # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else () fx_compatible = False test_missing_keys = False - test_model_parallel = True + test_model_parallel = False + test_pruning = False + test_head_masking = False # Rwkv does not support head masking def setUp(self): self.model_tester = RwkvModelTester(self) - self.config_tester = ConfigTester(self, config_class=RwkvConfig, n_embd=37) + self.config_tester = ConfigTester( + self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] + ) def test_config(self): self.config_tester.run_common_tests() @@ -274,6 +298,10 @@ def test_rwkv_lm_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_causl_lm(*config_and_inputs) + def test_state_equivalency(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_state_equivalency(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 1d249c3b52ddc3..fdb679529d2b41 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -93,15 +93,20 @@ class ConfigTester(object): - def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs): + def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs): self.parent = parent self.config_class = config_class self.has_text_modality = has_text_modality self.inputs_dict = kwargs + self.common_properties = common_properties def create_and_test_config_common_properties(self): config = self.config_class(**self.inputs_dict) - common_properties = ["hidden_size", "num_attention_heads", "num_hidden_layers"] + common_properties = ( + ["hidden_size", "num_attention_heads", "num_hidden_layers"] + if self.common_properties is None + else self.common_properties + ) # Add common fields for text models if self.has_text_modality: From 848ccf81c62f2d8258876abea79c5555ff7eabc6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 2 May 2023 16:12:17 +0000 Subject: [PATCH 11/27] correct tokenizer --- src/transformers/models/auto/tokenization_auto.py | 2 +- src/transformers/models/rwkv/configuration_rwkv.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 378699a8178fcf..82c3152cd787d4 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -297,7 +297,7 @@ ), ("roc_bert", ("RoCBertTokenizer", None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), - ("rwkv", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("rwkv", (None, "GPTNeoXTokenizerFasts" if is_tokenizers_available() else None)), ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index 6789301d19d3cb..ff03a62feae41a 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -89,8 +89,8 @@ def __init__( attention_hidden_size=None, intermediate_size=None, layer_norm_epsilon=1e-5, - # bos_token_id=50256, - # eos_token_id=50256, + bos_token_id=0, + eos_token_id=2, rescale_every=6, tie_word_embeddings=False, use_cache=True, @@ -106,4 +106,7 @@ def __init__( self.rescale_every = rescale_every self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) From b679baa840b3b015329631c9e7a4b65f256e96b1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 3 May 2023 09:22:43 +0000 Subject: [PATCH 12/27] some tweaks - fix config docstring - fix failing tests --- .../models/rwkv/configuration_rwkv.py | 10 +++++++++- tests/models/rwkv/test_modeling_rwkv.py | 20 +++++++++++++------ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index ff03a62feae41a..35d5ddcb3a395a 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -54,6 +54,12 @@ class RwkvConfig(PretrainedConfig): Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset. layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer + as GPTNeoX. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token in the vocabulary. Defaults to 2 as RWKV uses the same tokenizer as + GPTNeoX. rescale_every (`int`, *optional*, default to 6): At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every `rescale_every` layer. If set to 0 or a negative number, no rescale is done. @@ -109,4 +115,6 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + super().__init__( + tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs + ) diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index 14faa961a453e7..648cf08574c176 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -49,7 +49,6 @@ def __init__( vocab_size=99, hidden_size=32, num_hidden_layers=5, - num_attention_heads=4, intermediate_size=37, hidden_act="gelu", hidden_dropout_prob=0.1, @@ -57,7 +56,6 @@ def __init__( max_position_embeddings=512, type_vocab_size=16, type_sequence_label_size=2, - initializer_range=0.02, num_labels=3, num_choices=4, scope=None, @@ -73,7 +71,6 @@ def __init__( self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.hidden_dropout_prob = hidden_dropout_prob @@ -81,10 +78,9 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range self.num_labels = num_labels self.num_choices = num_choices - self.scope = None + self.scope = scope self.bos_token_id = vocab_size - 1 self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 @@ -148,7 +144,6 @@ def get_config( attn_pdrop=self.attention_probs_dropout_prob, n_positions=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, - initializer_range=self.initializer_range, use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, @@ -287,6 +282,19 @@ def setUp(self): self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) + def test_initialization(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, param in model.named_parameters(): + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + def test_config(self): self.config_tester.run_common_tests() From 10b5b81dde44e73b569c90d5f2f35b5bbe5edced Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 3 May 2023 11:56:20 +0000 Subject: [PATCH 13/27] fix CI tests - add output_attention / output_hidden_states - override test_initialization - fix failing CIs --- .../models/auto/tokenization_auto.py | 2 +- src/transformers/models/rwkv/modeling_rwkv.py | 42 ++++-- tests/models/rwkv/test_modeling_rwkv.py | 132 ++++++++++++++++-- 3 files changed, 154 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 82c3152cd787d4..cb6c91521de91b 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -297,7 +297,7 @@ ), ("roc_bert", ("RoCBertTokenizer", None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), - ("rwkv", (None, "GPTNeoXTokenizerFasts" if is_tokenizers_available() else None)), + ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 199be63b9211ba..f7365a68e0868e 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -372,7 +372,7 @@ def __init__(self, config, layer_id): self.attention = RwkvSelfAttention(config, layer_id) self.feed_forward = RwkvFeedForward(config, layer_id) - def forward(self, hidden, state=None, use_cache=False): + def forward(self, hidden, state=None, use_cache=False, output_attentions=False): if self.layer_id == 0: hidden = self.pre_ln(hidden) @@ -381,7 +381,14 @@ def forward(self, hidden, state=None, use_cache=False): feed_forward, state = self.feed_forward(self.ln2(hidden), state=state) hidden = hidden + feed_forward - return hidden, state + + outputs = (hidden, state) + if output_attentions: + outputs += (attention,) + else: + outputs += (None,) + + return outputs class RwkvPreTrainedModel(PreTrainedModel): @@ -631,18 +638,21 @@ def forward( if use_cache and state is None: shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) state = [ - torch.zeros(*shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=input_ids.device) + torch.zeros( + *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device + ) for i in range(5) ] state[4] -= 1e30 hidden_states = inputs_embeds - # TODO: add this. - # all_self_attentions = () if output_attentions else None - # all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): - hidden_states, state = block(hidden_states, state=state, use_cache=use_cache) + hidden_states, state, attentions = block( + hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions + ) if ( self.layers_are_rescaled and self.config.rescale_every > 0 @@ -650,12 +660,26 @@ def forward( ): hidden_states = hidden_states / 2 + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + hidden_states = self.ln_out(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if not return_dict: - (hidden_states,) + return (hidden_states, state, all_hidden_states, all_self_attentions) - return RwkvOutput(last_hidden_state=hidden_states, state=state, hidden_states=None, attentions=None) + return RwkvOutput( + last_hidden_state=hidden_states, + state=state, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) def _rescale_layers(self): # Layers should be rescaled for inference only. diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index 648cf08574c176..2c023277d260c8 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -15,6 +15,7 @@ import unittest +from unittest.util import safe_repr from transformers import RwkvConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device @@ -188,6 +189,7 @@ def prepare_config_and_inputs_for_decoder(self): ) def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + config.output_hidden_states = True model = RwkvModel(config=config) model.to(torch_device) model.eval() @@ -195,7 +197,7 @@ def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - self.parent.assertEqual(len(result.hidden_states), config.n_layer) + self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1) def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = RwkvForCausalLM(config) @@ -282,18 +284,27 @@ def setUp(self): self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] ) - def test_initialization(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + def assertInterval(self, member, container, msg=None): + r""" + Simple utility function to check if a member is inside an interval. + """ + if isinstance(member, torch.Tensor): + max_value, min_value = member.max().item(), member.min().item() + elif isinstance(member, list) or isinstance(member, tuple): + max_value, min_value = max(member), min(member) - for model_class in self.all_model_classes: - model = model_class(config=config) - for name, param in model.named_parameters(): - if param.requires_grad: - self.assertIn( - ((param.data.mean() * 1e9).round() / 1e9).item(), - [0.0, 1.0], - msg=f"Parameter {name} of model {model_class} seems not properly initialized", - ) + if not isinstance(container, list): + raise TypeError("container should be a list or tuple") + elif len(container) != 2: + raise ValueError("container should have 2 elements") + + expected_min, expected_max = container + + is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max) + + if not is_inside_interval: + standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container)) + self.fail(self._formatMessage(msg, standardMsg)) def test_config(self): self.config_tester.run_common_tests() @@ -310,6 +321,103 @@ def test_state_equivalency(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_state_equivalency(*config_and_inputs) + def test_initialization(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, param in model.named_parameters(): + if "time_decay" in name: + if param.requires_grad: + self.assertTrue(param.data.max().item() == 3.0) + self.assertTrue(param.data.min().item() == -5.0) + elif "time_first" in name: + if param.requires_grad: + # check if it's a ones like + self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) + elif any([x in name for x in ["time_mix_key", "time_mix_receptance"]]): + if param.requires_grad: + self.assertInterval( + param.data, + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif "time_mix_value" in name: + if param.requires_grad: + self.assertInterval( + param.data, + [0.0, 1.3], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the attention outputs of Rwkv are different from other models + it has a shape `batch_size, seq_len, hidden_size`. + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + batch_size = inputs["input_ids"].shape[0] + with torch.no_grad(): + outputs = model(**inputs) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + batch_size = inputs["input_ids"].shape[0] + with torch.no_grad(): + outputs = model(**inputs) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [batch_size, seq_len, config.hidden_size], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + batch_size = inputs["input_ids"].shape[0] + with torch.no_grad(): + outputs = model(**inputs) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [batch_size, seq_len, config.hidden_size], + ) + @slow def test_model_from_pretrained(self): for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: From ffb55f99d55b563a6d5de31bbb303d4d34e77d8d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 May 2023 14:09:35 +0000 Subject: [PATCH 14/27] fix conversion script - fix sharded case - add new arguments --- .../rwkv/convert_rwkv_checkpoint_to_hf.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py index 37561cd3716752..9d5cb84f176305 100644 --- a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py +++ b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert a RKWV checkpoint form BlinkDL to the Hugging Face format.""" +"""Convert a RWKV checkpoint from BlinkDL to the Hugging Face format.""" import argparse @@ -24,7 +24,7 @@ import torch from huggingface_hub import hf_hub_download -from transformers import PreTrainedTokenizerFast, RwkvConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, RwkvConfig from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint @@ -78,15 +78,18 @@ def convert_state_dict(state_dict): return state_dict -def convert_rmkv_checkpoint_to_hf_format(repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None): +def convert_rmkv_checkpoint_to_hf_format( + repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None +): # 1. If possible, build the tokenizer. if tokenizer_file is None: - print("No `--tokenizer_file` provided, we will only convert the model.") + print("No `--tokenizer_file` provided, we will use the default tokenizer.") vocab_size = 50277 + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") else: tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) vocab_size = len(tokenizer) - tokenizer.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) # 2. Build the config possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys()) @@ -136,9 +139,16 @@ def convert_rmkv_checkpoint_to_hf_format(repo_id, checkpoint_file, output_dir, s gc.collect() for shard_file in shard_files: - state_dict = torch.load(shard_file) + state_dict = torch.load(os.path.join(output_dir, shard_file)) torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, shard_file) + if push_to_hub: + if model_name is None: + raise ValueError("Please provide a `model_name` to push the model to the Hub.") + model = AutoModelForCausalLM.from_pretrained(output_dir) + model.push_to_hub(model_name, max_shard_size="2GB") + tokenizer.push_to_hub(model_name) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -164,8 +174,25 @@ def convert_rmkv_checkpoint_to_hf_format(repo_id, checkpoint_file, output_dir, s type=str, help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.", ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push to the Hub the converted model.", + ) + parser.add_argument( + "--model_name", + default=None, + type=str, + help="Name of the pushed model on the Hub, including the username / organization.", + ) args = parser.parse_args() convert_rmkv_checkpoint_to_hf_format( - args.repo_id, args.checkpoint_file, args.output_dir, size=args.size, tokenizer_file=args.tokenizer_file + args.repo_id, + args.checkpoint_file, + args.output_dir, + size=args.size, + tokenizer_file=args.tokenizer_file, + push_to_hub=args.push_to_hub, + model_name=args.model_name, ) From 7fd5702c2c088913ebf53d4e57c74bcf0db40b41 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 May 2023 14:38:32 +0000 Subject: [PATCH 15/27] add slow tests + more fixes on conversion script --- .../rwkv/convert_rwkv_checkpoint_to_hf.py | 2 +- tests/models/rwkv/test_modeling_rwkv.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py index 9d5cb84f176305..3f52cde2a26661 100644 --- a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py +++ b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py @@ -117,7 +117,7 @@ def convert_rmkv_checkpoint_to_hf_format( state_dict = convert_state_dict(state_dict) # 4. Split in shards and save - shards, index = shard_checkpoint(state_dict) + shards, index = shard_checkpoint(state_dict, max_shard_size="2GB") for shard_file, shard in shards.items(): torch.save(shard, os.path.join(output_dir, shard_file)) diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index 2c023277d260c8..f79f0c66060459 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -17,7 +17,7 @@ import unittest from unittest.util import safe_repr -from transformers import RwkvConfig, is_torch_available +from transformers import AutoTokenizer, RwkvConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -423,3 +423,20 @@ def test_model_from_pretrained(self): for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = RwkvModel.from_pretrained(model_name) self.assertIsNotNone(model) + + +@slow +class RWKVIntegrationTests(unittest.TestCase): + def setUp(self): + self.model_id = "ybelkada/rwkv-4-169m-pile" + self.model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + def test_simple_generate(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + output = self.model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) From e52be94dd1b7dba790535aa240963eb7ff77825a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 May 2023 15:02:28 +0000 Subject: [PATCH 16/27] add another test --- tests/models/rwkv/test_modeling_rwkv.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index f79f0c66060459..1813aca87c497e 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -429,14 +429,25 @@ def test_model_from_pretrained(self): class RWKVIntegrationTests(unittest.TestCase): def setUp(self): self.model_id = "ybelkada/rwkv-4-169m-pile" - self.model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) def test_simple_generate(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" + model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device) input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - output = self.model.generate(input_ids, max_new_tokens=10) + output = model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) + + def test_simple_generate_bf16(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + model = RwkvForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) + + output = model.generate(input_ids, max_new_tokens=10) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) From 3c02506a5f310b0ab7ad53855eb3458a24ad984f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 May 2023 15:22:54 +0000 Subject: [PATCH 17/27] final fixes --- src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py | 3 +++ src/transformers/models/rwkv/modeling_rwkv.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py index 3f52cde2a26661..24cf64b03ac576 100644 --- a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py +++ b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py @@ -142,6 +142,9 @@ def convert_rmkv_checkpoint_to_hf_format( state_dict = torch.load(os.path.join(output_dir, shard_file)) torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, shard_file) + del state_dict + gc.collect() + if push_to_hub: if model_name is None: raise ValueError("Please provide a `model_name` to push the model to the Hub.") diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index f7365a68e0868e..5c7621f16480ad 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -399,7 +399,6 @@ class RwkvPreTrainedModel(PreTrainedModel): config_class = RwkvConfig base_model_prefix = "rwkv" - # supports_gradient_checkpointing = True _no_split_modules = ["RwkvBlock"] _keep_in_fp32_modules = ["time_decay", "time_first"] From 69774da90b5d564540cf7fc87d83c9f22e159dc3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 4 May 2023 15:28:39 +0000 Subject: [PATCH 18/27] change single name variable --- src/transformers/models/rwkv/modeling_rwkv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 5c7621f16480ad..47b323dd6bad70 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -212,9 +212,9 @@ def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, re time_decay = -torch.exp(time_decay) - for t in range(seq_length): - current_key = key[:, t].float() - current_value = value[:, t] + for current_index in range(seq_length): + current_key = key[:, current_index].float() + current_value = value[:, current_index] # wkv computation at time t max_for_output = torch.maximum(max_state, current_key + time_first) @@ -222,7 +222,7 @@ def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, re e2 = torch.exp(current_key + time_first - max_for_output) numerator = e1 * num_state + e2 * current_value denominator = e1 * den_state + e2 - output[:, t] = (numerator / denominator).to(output.dtype) + output[:, current_index] = (numerator / denominator).to(output.dtype) # Update state for next iteration max_for_state = torch.maximum(max_state + time_decay, current_key) From d52d6dbc717df6b6f17ef755db7b3c7152873699 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 5 May 2023 08:49:57 +0000 Subject: [PATCH 19/27] add mock attention mask for pipeline to work --- src/transformers/models/rwkv/modeling_rwkv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 47b323dd6bad70..e51312a71e687d 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -742,6 +742,7 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=Non def forward( self, input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, state: Optional[List[torch.FloatTensor]] = None, labels: Optional[torch.LongTensor] = None, From 0d758e13eeb40613824873cc71e0ca307bdf5ecf Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 5 May 2023 09:51:17 +0000 Subject: [PATCH 20/27] correct eos token id --- src/transformers/models/rwkv/configuration_rwkv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index 35d5ddcb3a395a..48e6dcdd72fff7 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -57,8 +57,8 @@ class RwkvConfig(PretrainedConfig): bos_token_id (`int`, *optional*, defaults to 0): The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as GPTNeoX. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the end of sentence token in the vocabulary. Defaults to 2 as RWKV uses the same tokenizer as + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as GPTNeoX. rescale_every (`int`, *optional*, default to 6): At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every @@ -96,7 +96,7 @@ def __init__( intermediate_size=None, layer_norm_epsilon=1e-5, bos_token_id=0, - eos_token_id=2, + eos_token_id=0, rescale_every=6, tie_word_embeddings=False, use_cache=True, From df44a6060a7df94f8164862119481e6e518173c9 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 5 May 2023 11:49:13 +0000 Subject: [PATCH 21/27] fix nits --- src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py index 24cf64b03ac576..b340b9f028b3d7 100644 --- a/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py +++ b/src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py @@ -117,7 +117,7 @@ def convert_rmkv_checkpoint_to_hf_format( state_dict = convert_state_dict(state_dict) # 4. Split in shards and save - shards, index = shard_checkpoint(state_dict, max_shard_size="2GB") + shards, index = shard_checkpoint(state_dict) for shard_file, shard in shards.items(): torch.save(shard, os.path.join(output_dir, shard_file)) @@ -140,7 +140,7 @@ def convert_rmkv_checkpoint_to_hf_format( for shard_file in shard_files: state_dict = torch.load(os.path.join(output_dir, shard_file)) - torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, shard_file) + torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file)) del state_dict gc.collect() From 9e3efc5a7af85be9e3de7288cd9a96b9470dead1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 5 May 2023 14:11:11 +0000 Subject: [PATCH 22/27] add checkpoints --- src/transformers/models/rwkv/configuration_rwkv.py | 13 +++++++++++-- src/transformers/models/rwkv/modeling_rwkv.py | 13 +++++++++++-- tests/models/rwkv/test_modeling_rwkv.py | 2 +- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index 48e6dcdd72fff7..91fd2257353c72 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -22,7 +22,16 @@ logger = logging.get_logger(__name__) RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "sgugger/rwkv-4-pile-7b": "https://huggingface.co/sgugger/rwkv-4-pile-7b/resolve/main/config.json", + "RWKV/rwkv-4-169m-pile": "https://huggingface.co/RWKV/rwkv-4-169m-pile/resolve/main/config.json", + "RWKV/rwkv-4-430m-pile": "https://huggingface.co/RWKV/rwkv-4-430m-pile/resolve/main/config.json", + "RWKV/rwkv-4-1b5-pile": "https://huggingface.co/RWKV/rwkv-4-1b5-pile/resolve/main/config.json", + "RWKV/rwkv-4-3b-pile": "https://huggingface.co/RWKV/rwkv-4-3b-pile/resolve/main/config.json", + "RWKV/rwkv-4-7b-pile": "https://huggingface.co/RWKV/rwkv-4-7b-pile/resolve/main/config.json", + "RWKV/rwkv-4-14b-pile": "https://huggingface.co/RWKV/rwkv-4-14b-pile/resolve/main/config.json", + "RWKV/rwkv-raven-1b5": "https://huggingface.co/RWKV/rwkv-raven-1b5/resolve/main/config.json", + "RWKV/rwkv-raven-3b": "https://huggingface.co/RWKV/rwkv-raven-3b/resolve/main/config.json", + "RWKV/rwkv-raven-7b": "https://huggingface.co/RWKV/rwkv-raven-7b/resolve/main/config.json", + "RWKV/rwkv-raven-14b": "https://huggingface.co/RWKV/rwkv-raven-14b/resolve/main/config.json", } @@ -31,7 +40,7 @@ class RwkvConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the RWVK-4 - [sgugger/rwkv-4-pile-7b](https://huggingface.co/sgugger/rwkv-4-pile-7b) architecture. + [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index e51312a71e687d..b6229ed6eef13a 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -40,11 +40,20 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "sgugger/rwkv-4-pile-7b" +_CHECKPOINT_FOR_DOC = "RWKV/rwkv-4-169m-pile" _CONFIG_FOR_DOC = "RwkvConfig" RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "sgugger/rwkv-4-pile-7b", + "RWKV/rwkv-4-169m-pile", + "RWKV/rwkv-4-430m-pile", + "RWKV/rwkv-4-1b5-pile", + "RWKV/rwkv-4-3b-pile", + "RWKV/rwkv-4-7b-pile", + "RWKV/rwkv-4-14b-pile", + "RWKV/rwkv-raven-1b5", + "RWKV/rwkv-raven-3b", + "RWKV/rwkv-raven-7b", + "RWKV/rwkv-raven-14b", # See all RWKV models at https://huggingface.co/models?filter=rwkv ] diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index 1813aca87c497e..165e00775dab02 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -428,7 +428,7 @@ def test_model_from_pretrained(self): @slow class RWKVIntegrationTests(unittest.TestCase): def setUp(self): - self.model_id = "ybelkada/rwkv-4-169m-pile" + self.model_id = "RWKV/rwkv-4-169m-pile" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) def test_simple_generate(self): From 0b095c7905dd7c0c3ebcb361913661a9b75a5281 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 5 May 2023 17:03:15 +0200 Subject: [PATCH 23/27] Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/models/rwkv/configuration_rwkv.py | 6 ++---- src/transformers/models/rwkv/modeling_rwkv.py | 6 +++--- tests/models/rwkv/test_modeling_rwkv.py | 4 +--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index 91fd2257353c72..6d67f4510779ac 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -61,7 +61,7 @@ class RwkvConfig(PretrainedConfig): Dimensionality of the attention hidden states. Will default to `hidden_size` if unset. intermediate_size (`int`, *optional*): Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset. - layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + layer_norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers. bos_token_id (`int`, *optional*, defaults to 0): The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer @@ -91,9 +91,7 @@ class RwkvConfig(PretrainedConfig): ```""" model_type = "rwkv" - attribute_map = { - "max_position_embeddings": "context_length", - } + attribute_map = {"max_position_embeddings": "context_length"} def __init__( self, diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index b6229ed6eef13a..5a7c09c9e3e499 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright 2023 Peng Bo and HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -573,6 +573,8 @@ class RwkvCausalLMOutput(ModelOutput): state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the last state is returned and can be used to quickly generate the next logits. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -581,8 +583,6 @@ class RwkvCausalLMOutput(ModelOutput): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - use_cache (`bool`, *optional*): - If set to `True`, the last state is returned and can be used to quickly generate the next logits. """ diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py index 165e00775dab02..4afcc9b41f8682 100644 --- a/tests/models/rwkv/test_modeling_rwkv.py +++ b/tests/models/rwkv/test_modeling_rwkv.py @@ -253,9 +253,7 @@ def prepare_config_and_inputs_for_common(self): choice_labels, ) = config_and_inputs - inputs_dict = { - "input_ids": input_ids, - } + inputs_dict = {"input_ids": input_ids} return config, inputs_dict From f912a0d5f9143f86d5c3103707a17f2049586d1d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 5 May 2023 15:05:20 +0000 Subject: [PATCH 24/27] add `tie_word_embeddings` in docstring --- src/transformers/models/rwkv/configuration_rwkv.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/models/rwkv/configuration_rwkv.py b/src/transformers/models/rwkv/configuration_rwkv.py index 6d67f4510779ac..89b2f5fb648391 100644 --- a/src/transformers/models/rwkv/configuration_rwkv.py +++ b/src/transformers/models/rwkv/configuration_rwkv.py @@ -72,9 +72,12 @@ class RwkvConfig(PretrainedConfig): rescale_every (`int`, *optional*, default to 6): At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every `rescale_every` layer. If set to 0 or a negative number, no rescale is done. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the input token embeddings. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last state. + Example: ```python From 2a6dd6175762d4fbff4c13c6ded899bdfa4a90bd Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 5 May 2023 15:08:59 +0000 Subject: [PATCH 25/27] change tensor name --- src/transformers/models/rwkv/modeling_rwkv.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 5a7c09c9e3e499..5ac1dd777fbda1 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -422,12 +422,12 @@ def _init_weights(self, module): ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1 ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 - ddd = torch.tensor( + time_weight = torch.tensor( [i / hidden_size for i in range(hidden_size)], dtype=module.time_mix_key.dtype, device=module.time_mix_key.device, ) - ddd = ddd[None, None, :] + time_weight = time_weight[None, None, :] decay_speed = [ -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1) @@ -447,9 +447,9 @@ def _init_weights(self, module): module.time_decay.data = decay_speed module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag) - module.time_mix_key.data = torch.pow(ddd, ratio_1_to_almost0) - module.time_mix_value.data = torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 - module.time_mix_receptance.data = torch.pow(ddd, 0.5 * ratio_1_to_almost0) + module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) + module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0) elif isinstance(module, RwkvFeedForward): layer_id = module.layer_id num_hidden_layers = module.config.num_hidden_layers @@ -457,16 +457,16 @@ def _init_weights(self, module): ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 - ddd = torch.tensor( + time_weight = torch.tensor( [i / hidden_size for i in range(hidden_size)], dtype=module.time_mix_key.dtype, device=module.time_mix_key.device, ) - ddd = ddd[None, None, :] + time_weight = time_weight[None, None, :] with torch.no_grad(): - module.time_mix_key.data = torch.pow(ddd, ratio_1_to_almost0) - module.time_mix_receptance.data = torch.pow(ddd, ratio_1_to_almost0) + module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) + module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, RwkvModel): From 4fdbeefa2ef9759d1c0f48107c5b9807c56e5cd0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 9 May 2023 09:19:50 +0000 Subject: [PATCH 26/27] fix final nits --- README.md | 2 +- README_es.md | 2 +- README_hd.md | 2 +- README_ja.md | 2 +- README_ko.md | 2 +- README_zh-hans.md | 2 +- README_zh-hant.md | 2 +- docs/source/en/index.mdx | 2 +- src/transformers/models/rwkv/modeling_rwkv.py | 3 +-- 9 files changed, 9 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 1f368f9009d784..fdceaa1643e006 100644 --- a/README.md +++ b/README.md @@ -422,7 +422,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Peng Bo), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/README_es.md b/README_es.md index 716d4a080978d1..c3c569c531d60c 100644 --- a/README_es.md +++ b/README_es.md @@ -410,7 +410,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Peng Bo) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/README_hd.md b/README_hd.md index 5e68a690043595..af9c58011455cd 100644 --- a/README_hd.md +++ b/README_hd.md @@ -382,7 +382,7 @@ conda install -c huggingface transformers 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (झुईई टेक्नोलॉजी से), साथ में पेपर [रोफॉर्मर: रोटरी पोजिशन एंबेडिंग के साथ एन्हांस्ड ट्रांसफॉर्मर] (https://arxiv.org/pdf/2104.09864v1.pdf) जियानलिन सु और यू लू और शेंगफेंग पैन और बो वेन और युनफेंग लियू द्वारा प्रकाशित। -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Peng Bo से) Peng Bo. द्वाराअनुसंधान पत्र [this repo](https://github.com/BlinkDL/RWKV-LM) के साथ जारी किया गया +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Bo Peng से) Bo Peng. द्वाराअनुसंधान पत्र [this repo](https://github.com/BlinkDL/RWKV-LM) के साथ जारी किया गया 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (Meta AI से) Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. द्वाराअनुसंधान पत्र [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) के साथ जारी किया गया 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP से) साथ देने वाला पेपर [भाषण पहचान के लिए अनसुपरवाइज्ड प्री-ट्रेनिंग में परफॉर्मेंस-एफिशिएंसी ट्रेड-ऑफ्स](https ://arxiv.org/abs/2109.06870) फेलिक्स वू, क्वांगयुन किम, जिंग पैन, क्यू हान, किलियन क्यू. वेनबर्गर, योव आर्टज़ी द्वारा। diff --git a/README_ja.md b/README_ja.md index e21fa792ab32a5..953ff5598e4af7 100644 --- a/README_ja.md +++ b/README_ja.md @@ -444,7 +444,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (Facebook から) Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli から公開された研究論文: [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (WeChatAI から) HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou から公開された研究論文: [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (ZhuiyiTechnology から), Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu から公開された研究論文: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Peng Bo から) Peng Bo. から公開された研究論文 [this repo](https://github.com/BlinkDL/RWKV-LM) +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Bo Peng から) Bo Peng. から公開された研究論文 [this repo](https://github.com/BlinkDL/RWKV-LM) 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (NVIDIA から) Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo から公開された研究論文: [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (Meta AI から) Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. から公開された研究論文 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP から) Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi から公開された研究論文: [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) diff --git a/README_ko.md b/README_ko.md index 9dcf109f548539..2707f191dad0b6 100644 --- a/README_ko.md +++ b/README_ko.md @@ -359,7 +359,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (Facebook 에서) Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli 의 [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) 논문과 함께 발표했습니다. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (WeChatAI 에서) HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou 의 [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 논문과 함께 발표했습니다. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (ZhuiyiTechnology 에서) Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 의 a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 논문과 함께 발표했습니다. -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Peng Bo 에서 제공)은 Peng Bo.의 [this repo](https://github.com/BlinkDL/RWKV-LM)논문과 함께 발표했습니다. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (Bo Peng 에서 제공)은 Bo Peng.의 [this repo](https://github.com/BlinkDL/RWKV-LM)논문과 함께 발표했습니다. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (NVIDIA 에서) Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 의 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 논문과 함께 발표했습니다. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (Meta AI 에서 제공)은 Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.의 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf)논문과 함께 발표했습니다. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP 에서) Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi 의 [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 5417886f0c0837..74d39b1df3d6aa 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -383,7 +383,7 @@ conda install -c huggingface transformers 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (来自 Facebook) 伴随论文 [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) 由 Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli 发布。 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (来自 WeChatAI), 伴随论文 [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 由 HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou 发布。 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (来自 ZhuiyiTechnology), 伴随论文 [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 由 Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 发布。 -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (来自 Peng Bo) 伴随论文 [this repo](https://github.com/BlinkDL/RWKV-LM) 由 Peng Bo 发布。 +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (来自 Bo Peng) 伴随论文 [this repo](https://github.com/BlinkDL/RWKV-LM) 由 Bo Peng 发布。 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (来自 NVIDIA) 伴随论文 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 由 Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 发布。 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (来自 Meta AI) 伴随论文 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) 由 Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick 发布。 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (来自 ASAPP) 伴随论文 [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) 由 Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 591a5aa18b4394..4fd2a3caad5491 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -395,7 +395,7 @@ conda install -c huggingface transformers 1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Peng Bo) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. +1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index f1b86db8711993..c9af68ae078e88 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -196,7 +196,7 @@ The documentation is organized into five sections: 1. **[RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli. 1. **[RoCBert](model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -1. **[RWKV](model_doc/rwkv)** (from Peng Bo), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Peng Bo. +1. **[RWKV](model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. 1. **[SegFormer](model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 5ac1dd777fbda1..dd85c279daeb60 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2023 Peng Bo and HuggingFace Inc. team. +# Copyright 2023 Bo Peng and HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -145,7 +145,6 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa forward_func = rwkv_cuda_kernel.forward_with_state_bf16 else: forward_func = rwkv_cuda_kernel.forward_with_state - # TODO: update CUDA kernel so it uses the initial state provided here. forward_func(time_decay, time_first, key, value, output, state) else: forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward From d1c2b15e67c141f1ddff4ad0eaed7f6eb934fead Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 9 May 2023 11:56:40 -0400 Subject: [PATCH 27/27] Trigger CI