diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 4b376261d30d2..0456b4bc263cc 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -20,6 +20,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 1152) \ f(in_T, out_T, W_T, narrow, 1280) \ f(in_T, out_T, W_T, narrow, 1536) \ + f(in_T, out_T, W_T, narrow, 1664) \ f(in_T, out_T, W_T, narrow, 1728) \ f(in_T, out_T, W_T, narrow, 1792) \ f(in_T, out_T, W_T, narrow, 2048) \ @@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ + f(in_T, out_T, W_T, narrow, 5888) \ f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6400) \ f(in_T, out_T, W_T, narrow, 6848) \ @@ -45,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 9216) \ f(in_T, out_T, W_T, narrow, 10240) \ f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 11264) \ f(in_T, out_T, W_T, narrow, 12288) \ f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ @@ -53,6 +56,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ + f(in_T, out_T, W_T, narrow, 22528) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27648) \ @@ -91,6 +95,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 1152, narrow) \ f(in_T, out_T, W_T, 1280, narrow) \ f(in_T, out_T, W_T, 1536, narrow) \ + f(in_T, out_T, W_T, 1664, narrow) \ f(in_T, out_T, W_T, 1728, narrow) \ f(in_T, out_T, W_T, 1792, narrow) \ f(in_T, out_T, W_T, 2048, narrow) \ @@ -107,6 +112,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 5120, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \ + f(in_T, out_T, W_T, 5888, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \ f(in_T, out_T, W_T, 6400, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \ @@ -116,6 +122,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 9216, narrow) \ f(in_T, out_T, W_T, 10240, narrow) \ f(in_T, out_T, W_T, 11008, narrow) \ + f(in_T, out_T, W_T, 11264, narrow) \ f(in_T, out_T, W_T, 12288, narrow) \ f(in_T, out_T, W_T, 13696, narrow) \ f(in_T, out_T, W_T, 13824, narrow) \ @@ -124,6 +131,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 16384, narrow) \ f(in_T, out_T, W_T, 20480, narrow) \ f(in_T, out_T, W_T, 22016, narrow) \ + f(in_T, out_T, W_T, 22528, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ f(in_T, out_T, W_T, 27648, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index f021c003b1322..d87658e5dd886 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -53,6 +53,7 @@ def _lora_ref_impl( 1152, 1280, 1536, + 1664, 2048, 2304, 2560, @@ -66,6 +67,7 @@ def _lora_ref_impl( 5120, 5504, 5632, + 5888, 6144, 6400, 6848, @@ -75,10 +77,12 @@ def _lora_ref_impl( 9216, 10240, 11008, + 11264, 13824, 14336, 15360, 22016, + 22528, 24576, 27392, 27648,