Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1263,25 +1263,30 @@ struct vk_op_diag_mask_push_constants {

// Push-constant block for the RoPE shaders.
//
// ggml_vk_make_rope_constants() fills this struct with a positional aggregate
// initializer, so the member order below is significant: reordering, adding or
// removing a member requires the same change in that initializer. The member
// order also mirrors `struct rope_params` in vulkan-shaders/rope_params.glsl
// (NOTE(review): presumably the two must stay in lock-step — confirm).
//
// Source (src0) and destination are addressed through explicit per-dimension
// strides rather than a flattened row index.
struct vk_op_rope_push_constants {
    uint32_t rope_mode;        // rope `mode` value, compared against GGML_ROPE_TYPE_* in the shaders
    uint32_t nrows;            // ggml_nrows(src0); upper bound for the flattened row id in the shaders
    uint32_t n_dims;           // elements with i0 >= n_dims take a separate path in the shaders
    float    freq_scale;
    float    freq_base;
    float    ext_factor;
    float    attn_factor;
    float    corr_dims[2];
    float    theta_scale;      // base of pow(theta_scale, i0/2) in the shaders
    uint32_t has_ff;           // non-zero: per-pair frequency factors are read from rope_data_ff
    int32_t  sections[4];      // section sizes used by the multi/vision variants
    uint32_t is_imrope;
    uint32_t is_back;          // set from `backprop` by ggml_vk_make_rope_constants()
    uint32_t set_rows_stride;  // non-zero enables the fused ROPE + VIEW + SET_ROWS destination addressing

    // src0 extents (src0->ne[0..2]).
    uint32_t ne00;
    uint32_t ne01;
    uint32_t ne02;

    // src0 strides, in elements (src0->nb[i] / type size).
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;

    // dst strides, in elements (dst->nb[i] / type size).
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
};
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");

// For fused rms_norm+mul+rope(+view+set_rows)
struct vk_op_rms_norm_mul_rope_push_constants {
Expand Down Expand Up @@ -10405,12 +10410,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *

uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);

uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);

vk_op_rope_push_constants rope {
(uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
(uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,

(uint32_t)src0->ne[0],
(uint32_t)src0->ne[1],
(uint32_t)src0->ne[2],
nb01, nb02, nb03,
nb11, nb12, nb13,
};

return rope;
Expand Down Expand Up @@ -14798,6 +14813,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_REPEAT_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ROPE:
return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ROPE_BACK:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
Expand Down
5 changes: 2 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
#if RMS_NORM_ROPE_FUSION
barrier();
rope_params rp = p.rope;
uint rope_row = (samp*nchannels + channel)*nrows + row;
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
rope_neox(t, rope_row, rp);
rope_neox(t, row, channel, samp, rp);
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
rope_norm(t, rope_row, rp);
rope_norm(t, row, channel, samp, rp);
}
}
#endif
Expand Down
99 changes: 36 additions & 63 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
return 1.0f - min(1.0f, max(0.0f, y));
}

// Returns the index of element (i0, i01, i02, i03) in the rope input.
//
// With RMS_NORM_ROPE_FUSION the input is a single row held in shared memory,
// so only the in-row offset i0 is used. Otherwise the index is built from the
// source strides nb01/nb02/nb03 carried in `p`, which lets the source be
// non-contiguous across rows, channels and samples.
uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
#if RMS_NORM_ROPE_FUSION
    // Per-row offset in shared memory
    const uint ix = i0;
#else
    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
#endif
    return ix;
}
Expand All @@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
sin_theta = sin(theta) * mscale;
}

void rope_norm(const uint i0, const uint i1, rope_params p) {
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;

if (i0 >= ne0) {
void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
if (i0 >= p.ne00) {
return;
}

// i1 is actually i2*nb2+i1, but the rows are contiguous
const uint i01 = i1 % ne1;
const uint i02 = i1 / ne1;

uint idst = i1*ne0 + i0;
const uint ix = rope_a_coord(i0, i01, i02, p);
uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
const uint ix = rope_a_coord(i0, i1, i2, i3, p);

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
if (p.set_rows_stride != 0) {
idst = i01*ne0 + i0;
idst += rope_data_i[i02].x * p.set_rows_stride;
idst = i1*p.nb11 + i0;
idst += rope_data_i[i2].x * p.set_rows_stride;
}

if (i0 >= p.n_dims) {
Expand All @@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
return;
}

const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);

const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;

Expand All @@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
}

void rope_neox(const uint i0, const uint i1, rope_params p) {
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;

if (i0 >= ne0) {
void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
if (i0 >= p.ne00) {
return;
}

const uint i01 = i1 % ne1;
const uint i02 = i1 / ne1;

uint idst = i1*ne0 + i0/2;
const uint ix = rope_a_coord(i0/2, i01, i02, p);
uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
if (p.set_rows_stride != 0) {
idst = i01*ne0 + i0/2;
idst += rope_data_i[i02].x * p.set_rows_stride;
idst = i1*p.nb11 + i0/2;
idst += rope_data_i[i2].x * p.set_rows_stride;
}

if (i0 >= p.n_dims) {
Expand All @@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
return;
}

const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);

const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;

Expand All @@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
}


void rope_multi(const uint i0, const uint i1, rope_params p) {
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;

if (i0 >= ne0) {
void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
if (i0 >= p.ne00) {
return;
}

const uint i01 = i1 % ne1;
const uint i02 = i1 / ne1;

uint idst = i1*ne0 + i0/2;
const uint ix = rope_a_coord(i0/2, i01, i02, p);
uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
if (p.set_rows_stride != 0) {
idst = i01*ne0 + i0/2;
idst += rope_data_i[i02].x * p.set_rows_stride;
idst = i1*p.nb11 + i0/2;
idst += rope_data_i[i2].x * p.set_rows_stride;
}

if (i0 >= p.n_dims) {
Expand All @@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
float theta_base = 0.0;
if (p.is_imrope != 0) {
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
} else {
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
}
} else {
if (sector < p.sections[0]) {
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
}
}

Expand All @@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
}

void rope_vision(const uint i0, const uint i1, rope_params p) {
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;

if (i0 >= ne0) {
void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
if (i0 >= p.ne00) {
return;
}

const uint i01 = i1 % ne1;
const uint i02 = i1 / ne1;

const uint idst = i1*ne0 + i0/2;
const uint ix = rope_a_coord(i0/2, i01, i02, p);
const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);

const int sect_dims = p.sections[0] + p.sections[1];
const int sec_w = p.sections[1] + p.sections[0];
Expand All @@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
float theta_base = 0.0;
if (sector < p.sections[0]) {
const uint p0 = sector;
theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
}
else if (sector >= p.sections[0] && sector < sec_w) {
const uint p0 = sector - p.sections[0];
theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
}

const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
Expand Down
11 changes: 7 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

// Entry point of the multi-section RoPE shader.
//
// Each invocation handles one pair of elements (starting at column i0) of one
// row. The flattened row id is split into (i1, i2, i3) so rope_multi() can
// address source and destination through their own strides.
void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;

    // Rows are spread over the x and z dispatch dimensions, 32768 rows per z slice.
    // NOTE(review): presumably matches the host-side dispatch split — confirm.
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }

    // Unflatten: row == (i3*ne02 + i2)*ne01 + i1.
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_multi(i0, i1, i2, i3, pc);
}
11 changes: 7 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

// Entry point of the NeoX-style RoPE shader.
//
// Each invocation handles one pair of elements (selected by i0) of one row.
// The flattened row id is split into (i1, i2, i3) so rope_neox() can address
// source and destination through their own strides.
void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;

    // Rows are spread over the x and z dispatch dimensions, 32768 rows per z slice.
    // NOTE(review): presumably matches the host-side dispatch split — confirm.
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }

    // Unflatten: row == (i3*ne02 + i2)*ne01 + i1.
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_neox(i0, i1, i2, i3, pc);
}
11 changes: 7 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

// Entry point of the standard ("norm") RoPE shader.
//
// Each invocation handles one adjacent pair of elements (i0, i0+1) of one row.
// The flattened row id is split into (i1, i2, i3) so rope_norm() can address
// source and destination through their own strides.
void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;

    // Rows are spread over the x and z dispatch dimensions, 32768 rows per z slice.
    // NOTE(review): presumably matches the host-side dispatch split — confirm.
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }

    // Unflatten: row == (i3*ne02 + i2)*ne01 + i1.
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_norm(i0, i1, i2, i3, pc);
}
15 changes: 10 additions & 5 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,29 @@

// Parameters shared by all RoPE shader variants (rope_norm / rope_neox /
// rope_multi / rope_vision).
//
// NOTE(review): the member order mirrors vk_op_rope_push_constants on the host
// side, which is filled positionally; presumably the two must be changed
// together — confirm.
struct rope_params {
    uint rope_mode;         // compared against GGML_ROPE_TYPE_* to pick the variant
    uint nrows;             // total number of rows; bounds the flattened row id
    uint n_dims;            // elements with i0 >= n_dims take a separate path
    float freq_scale;
    float freq_base;
    float ext_factor;
    float attn_factor;
    float corr_dims[2];
    float theta_scale;      // base of pow(theta_scale, i0/2)
    uint has_ff;            // non-zero: read per-pair frequency factors from rope_data_ff
    int sections[4];        // section sizes used by rope_multi / rope_vision
    uint is_imrope;         // non-zero selects the alternate section mapping in rope_multi
    uint is_back;
    uint set_rows_stride;   // non-zero enables the fused ROPE + VIEW + SET_ROWS destination addressing

    // Source extents.
    uint ne00;
    uint ne01;
    uint ne02;

    // Source strides, in elements.
    uint nb01;
    uint nb02;
    uint nb03;

    // Destination strides, in elements.
    uint nb11;
    uint nb12;
    uint nb13;
};

#endif // !defined(GGML_ROPE_PARAMS)
11 changes: 7 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

// Entry point of the vision RoPE shader.
//
// Each invocation handles one pair of elements (selected by i0) of one row.
// The flattened row id is split into (i1, i2, i3) so rope_vision() can address
// source and destination through their own strides.
void main() {
    const uint i0 = 2*gl_GlobalInvocationID.y;

    // Rows are spread over the x and z dispatch dimensions, 32768 rows per z slice.
    // NOTE(review): presumably matches the host-side dispatch split — confirm.
    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
    if (row >= pc.nrows) {
        return;
    }

    // Unflatten: row == (i3*ne02 + i2)*ne01 + i1.
    const uint i3 = row / (pc.ne01*pc.ne02);
    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);

    rope_vision(i0, i1, i2, i3, pc);
}
Loading