From 2a9b730890853eb92d9aa1345f5936371d4a2263 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 8 Jul 2025 08:56:53 +0300 Subject: [PATCH 1/4] cuda : fix rope non-cont ggml-ci --- ggml/src/ggml-cuda/rope.cu | 32 ++++++++++++++++---------------- tests/test-backend-ops.cpp | 18 ++++++++++++------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 18f691b2d3103..ba74368024279 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -50,21 +50,21 @@ static __global__ void rope_norm( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0; + const int ix = channel_x*s2 + row_x*s1 + i0; + if (i0 >= n_dims) { const int i = row_dst*ne0 + i0; - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; + dst[i + 0] = x[ix + 0]; + dst[i + 1] = x[ix + 1]; return; } - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - const int idst = row_dst*ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -94,21 +94,21 @@ static __global__ void rope_neox( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { const int i = row_dst*ne0 + i0; - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; + dst[i + 0] = x[ix + i0/2 + 0]; + dst[i + 1] = x[ix + i0/2 + 1]; return; } - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 652856a35d5e7..cf195d3577db5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5323,12 +5323,12 @@ static std::vector> make_test_cases_eval() { for (bool fw : {true, false}) { // fw == forward bool all = true; - for (float v : { 0, 1 }) { - for (float fs : { 1.0f, 1.4245f }) { - for (float ef : { 0.0f, 0.7465f }) { - for (float af : { 1.0f, 1.4245f }) { - for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (bool ff : {false, true}) { // freq_factors + for (float fs : { 1.0f, 1.4245f }) { + for (float ef : { 0.0f, 0.7465f }) { + for (float af : { 1.0f, 1.4245f }) { + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (bool ff : {false, true}) { // freq_factors + for (float v : { 0, 1 }) { test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { @@ -5341,8 +5341,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { From 31af27a74b14c7d1d165a0f756ff33c955c99be6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 8 Jul 2025 09:14:48 +0300 Subject: [PATCH 2/4] cont : fix multi-rope + add test ggml-ci --- ggml/src/ggml-cuda/rope.cu | 16 ++++++++-------- tests/test-backend-ops.cpp | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index ba74368024279..476cf398fd050 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -138,21 +138,21 @@ static __global__ void rope_multi( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { const int i = row_dst*ne0 + i0; - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; + dst[i + 0] = x[ix + i0/2 + 0]; + dst[i + 1] = x[ix + i0/2 + 1]; return; } - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; - const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index cf195d3577db5..b54bcc8a35e64 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5342,9 +5342,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) - test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) @@ -5354,6 +5354,8 @@ static std::vector> make_test_cases_eval() { if (all) { test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } From 96998d7cf2338fd564c7754557806caac28ff05e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 8 Jul 2025 09:25:44 +0300 Subject: [PATCH 3/4] sycl : try fix ggml-ci --- ggml/src/ggml-sycl/rope.cpp | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index e44c6b6ef8f42..3f14a6e56af76 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -47,18 +47,18 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0; const int i2 = channel0 * s2 + row0 * s1 + i0; + if (i0 >= n_dims) { + const int i = row * ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -88,18 +88,18 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0 / 2; const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + if (i0 >= n_dims) { + const int i = row * ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i0 / 2 + i); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -129,17 +129,17 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const } const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = (row_dst * ne0) + (i0 / 2); const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + if (i0 >= n_dims) { + const int i = row_dst*ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i0 / 2 + i); + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; From bcbf7bc86a3ca2b548531ebd1dd6d57dbfcd1c06 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 8 Jul 2025 09:45:17 +0300 Subject: [PATCH 4/4] cont : fix sycl + clean-up cuda ggml-ci --- ggml/src/ggml-cuda/rope.cu | 18 ++++++------------ ggml/src/ggml-sycl/rope.cpp | 7 ++----- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 476cf398fd050..d058504cd6cc0 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -57,10 +57,8 @@ static __global__ void rope_norm( const int ix = channel_x*s2 + row_x*s1 + i0; if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[ix + 0]; - dst[i + 1] = x[ix + 1]; + dst[idst + 0] = x[ix + 0]; + dst[idst + 1] = x[ix + 1]; return; } @@ -101,10 +99,8 @@ static __global__ void rope_neox( const int ix = channel_x*s2 + row_x*s1 + i0/2; if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[ix + i0/2 + 0]; - dst[i + 1] = x[ix + i0/2 + 1]; + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; return; } @@ -145,10 +141,8 @@ static __global__ void rope_multi( const int ix = channel_x*s2 + row_x*s1 + i0/2; if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[ix + i0/2 + 0]; - dst[i + 1] = x[ix + i0/2 + 1]; + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; return; } diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 3f14a6e56af76..1b60226dcd531 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -54,7 +54,6 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const const int i2 = channel0 * s2 + row0 * s1 + i0; if (i0 >= n_dims) { - const int i = row * ne0 + i0; *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i2); return; } @@ -95,8 +94,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i0 / 2 + i); + *reinterpret_cast *>(dst + i + i0 / 2) = *reinterpret_cast *>(x + i2 + i0 / 2); return; } @@ -135,8 +133,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i0 / 2 + i); + *reinterpret_cast *>(dst + idst + i0 / 2) = *reinterpret_cast *>(x + i0 / 2 + ix); return; }