Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions ggml/src/ggml-cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,19 @@ static __global__ void rope_norm(

const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;

const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;

if (i0 >= n_dims) {
dst[idst + 0] = x[ix + 0];
dst[idst + 1] = x[ix + 1];

return;
}

const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
Expand Down Expand Up @@ -94,21 +92,19 @@ static __global__ void rope_neox(

const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;

const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;

if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

return;
}

const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
Expand Down Expand Up @@ -138,21 +134,19 @@ static __global__ void rope_multi(

const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;

const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;

if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

return;
}

const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
Expand Down
33 changes: 15 additions & 18 deletions ggml/src/ggml-sycl/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row * ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}

const int row0 = row % ne1;
const int channel0 = row / ne1;

const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;

if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
return;
}

const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
Expand Down Expand Up @@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const

const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row * ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}

const int row0 = row % ne1;
const int channel0 = row / ne1;

const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;

if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
return;
}

const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);

const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
Expand Down Expand Up @@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);

if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
return;
}

const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);

if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
return;
}

const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
Expand Down
20 changes: 14 additions & 6 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5323,12 +5323,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (bool fw : {true, false}) { // fw == forward
bool all = true;

for (float v : { 0, 1 }) {
for (float fs : { 1.0f, 1.4245f }) {
for (float ef : { 0.0f, 0.7465f }) {
for (float af : { 1.0f, 1.4245f }) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (bool ff : {false, true}) { // freq_factors
for (float fs : { 1.0f, 1.4245f }) {
for (float ef : { 0.0f, 0.7465f }) {
for (float af : { 1.0f, 1.4245f }) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (bool ff : {false, true}) { // freq_factors
for (float v : { 0, 1 }) {
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B

if (all) {
Expand All @@ -5341,13 +5341,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)

test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));

test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
}

if (all) {
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw));
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
}

Expand Down
Loading