Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
const bool is_neox = mode & 2;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;

if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
Expand Down Expand Up @@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
// both mrope and vision kernels have sections
if (is_mrope || is_vision) {
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
}
// only mrope has is_imrope
if (is_mrope && !is_vision) {
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
}

size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
Expand Down
74 changes: 50 additions & 24 deletions ggml/src/ggml-opencl/kernels/rope.cl
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
int4 sections,
int is_imrope
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
Expand All @@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;

if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
theta_base = (float) pos[i2 + ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
theta_base = (float) pos[i2 + ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
theta_base = (float) pos[i2 + ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + ne02 * 3];
}
} else {
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
}

const float theta = theta_base * pow(freq_base, inv_ndims*i0);
Expand Down Expand Up @@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
int4 sections,
int is_imrope
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
Expand All @@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;

if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
theta_base = (float) pos[i2 + ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
theta_base = (float) pos[i2 + ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
theta_base = (float) pos[i2 + ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + ne02 * 3];
}
} else {
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
}

const float theta = theta_base * pow(freq_base, inv_ndims*i0);
Expand Down
Loading