Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2752,7 +2752,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
if (src0_2) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
dst->src[4], dst->src[5],
(const char *)src0_1->data, src0_2 ? (const char *)src0_2->data : nullptr,
(const char *)src0_1->data, (const char *)src0_2->data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
Expand All @@ -2763,7 +2763,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
if (!dst->src[4]) {
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst,
nullptr, nullptr,
(const char *)local_src0_1.data, (const char *)local_src0_2.data,
(const char *)local_src0_2.data, (const char *)local_src0_1.data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
} else {
Expand All @@ -2773,8 +2773,8 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
auto local_bias_2 = local_bias_1;
local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0];
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst,
&local_bias_1, &local_bias_2,
(const char *)local_src0_1.data, (const char *)local_src0_2.data,
&local_bias_2, &local_bias_1,
(const char *)local_src0_2.data, (const char *)local_src0_1.data,
(const float *)src1->data, src1_quantized.get(),
(float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, limit, stream);
}
Expand Down Expand Up @@ -2922,7 +2922,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten

auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_up_gate_contiguous.get() + dst->ne[0], (const float *)dst_up_gate_contiguous.get(),
ggml_swiglu_oai_cuda_f32((const float *)dst_up_gate_contiguous.get(), (const float *)dst_up_gate_contiguous.get() + dst->ne[0],
(float *)dst->data, ggml_nelements(dst), dst->ne[0], src0_1->ne[1], src0_1->ne[1],
1.702f, 7.0f, stream);
} else {
Expand Down Expand Up @@ -3121,7 +3121,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
}
} else {
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(),
ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get(), (const float *)dst_up_contiguous.get() + dst->ne[0],
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row)/2, dst->ne[0], src0_1->ne[1], src0_1->ne[1],
1.702f, 7.0f, stream);
} else {
Expand Down
9 changes: 5 additions & 4 deletions ggml/src/ggml-cuda/unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ static __global__ void fused_mul_silu_f32(const float * x, float * dst, const in
int row = i / ne0;
int j = i % ne0;
auto x_row = x + 2*row*ne0;
dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0]));
//dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0]));
dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j]));
}

static __global__ void fused_mul_silu_f32(const float * x, float * dst, const int k, const int ne0, float limit) {
Expand Down Expand Up @@ -148,7 +149,7 @@ static __global__ void fused_mul_relu_f32(const float * x, float * dst, const in
int row = i / ne0;
int j = i % ne0;
auto x_row = x + 2*row*ne0;
dst[i] = fmaxf(x_row[j + ne0], 0) * x_row[j];
dst[i] = fmaxf(x_row[j], 0) * x_row[j + ne0];
}

static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
Expand All @@ -174,8 +175,8 @@ static __global__ void fused_mul_gelu_f32(const float * x, float * dst, const in
int row = i / ne0;
int j = i % ne0;
auto x_row = x + 2*row*ne0;
float xi = x_row[j + ne0];
dst[i] = 0.5f*xi*x_row[j]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
float xi = x_row[j];
dst[i] = 0.5f*xi*x_row[j+ne0]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}

static __global__ void tanh_f32(const float * x, float * dst, int k) {
Expand Down
21 changes: 14 additions & 7 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -17448,12 +17448,20 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
continue;
}

const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
const char * src0_2_cur = src0_2 ? (const char *) src0_2->data + cur_a*nb02 : src0_1_cur + nb02/2;
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
if (up_b_cur && !gate_b_cur) {
gate_b_cur = up_b_cur + nb41/2;
const char *src0_1_cur, *src0_2_cur, *up_b_cur = NULL, *gate_b_cur = NULL;
if (src0_2) {
src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
} else {
src0_2_cur = (const char *) src0_1->data + cur_a*nb02;
src0_1_cur = src0_2_cur + nb02/2;
if (up_b) {
GGML_ASSERT(!gate_b);
gate_b_cur = (const char *)up_b->data + cur_a*nb41;
up_b_cur = gate_b_cur + nb41/2;
}
}

const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
Expand All @@ -17462,7 +17470,6 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const int64_t nr0 = src0_2 ? ne01 : ne01/2; // src0 rows
const int64_t nr1 = cne1; // src1 rows

//if (ith == 0) printf("Calling iqk_moe_fused_up_gate with nr0 = %d, nr1 = %d, ne00 = %d, ne11 = %d\n", (int)nr0, (int)nr1, (int)ne00, (int)ne11);
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
type, src0_1_cur, src0_2_cur, nb01,
vec_dot_type, (const char *)wdata, row_size,
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
Expand Down
4 changes: 2 additions & 2 deletions src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1262,8 +1262,8 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * up_gate_exps, ggml_tensor * up_gate_exps_b,
ggml_tensor * shexp_gate) {

auto split_up_exps = (ggml_split_tensor_t *)up_exps->extra;
auto split_gate_exps = (ggml_split_tensor_t *)gate_exps->extra;
auto split_up_exps = up_exps ? (ggml_split_tensor_t *)up_exps->extra : nullptr;
auto split_gate_exps = gate_exps ? (ggml_split_tensor_t *)gate_exps->extra : nullptr;
auto split_down_exps = (ggml_split_tensor_t *)down_exps->extra;
auto split_up_shexp = up_shexp ? (ggml_split_tensor_t *)up_shexp->extra : nullptr;
auto split_gate_shexp = gate_shexp ? (ggml_split_tensor_t *)gate_shexp->extra : nullptr;
Expand Down
63 changes: 43 additions & 20 deletions src/llama-load-tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3034,15 +3034,27 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {

ggml_context *ctx_ffn_gate, *ctx_ffn_up, *ctx_ffn_down;
layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
bool merged = ml.merge_up_gate_exps && merge_up_gate_exps(tn, i, 2);
use_mmap_buffer &= !merged;
if (merged) {
bool merged = false;
auto ug_name = tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i);
auto ug_meta = ml.get_tensor_meta(ug_name.c_str());
if (ug_meta) {
auto ug_name_b = tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "bias", i);
auto ug_meta_b = ml.get_tensor_meta(ug_name_b.c_str());
GGML_ASSERT(ug_meta_b);
layer.ffn_up_gate_exps = create_tensor(ctx_split, ug_name, { ug_meta->ne[0], ug_meta->ne[1], ug_meta->ne[2] }, 0);
layer.ffn_up_gate_exps_b = create_tensor(ctx_split, ug_name_b, { ug_meta_b->ne[0], ug_meta_b->ne[1], ug_meta_b->ne[2] }, 0);
ctx_ffn_gate = ctx_ffn_up = ctx_split;
} else {
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i),
{ n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up);
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i),
{ n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_gate);
merged = ml.merge_up_gate_exps && merge_up_gate_exps(tn, i, 2);
use_mmap_buffer &= !merged;
if (merged) {
ctx_ffn_gate = ctx_ffn_up = ctx_split;
} else {
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i),
{ n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up);
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i),
{ n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_gate);
}
}
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i),
{n_ff_exp, n_embd, n_expert}, 0, &ctx_ffn_down);
Expand All @@ -3053,7 +3065,7 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) {
auto ctx_gate_b = ctx_ffn_gate == ctx_split ? ctx_split : ctx_layer;
auto ctx_down_b = ctx_ffn_down == ctx_split ? ctx_split : ctx_layer;
auto ctx_up_b = ctx_ffn_up == ctx_split ? ctx_split : ctx_layer;
if (!merged) {
if (!ug_meta && !merged) {
layer.ffn_up_exps_b = create_tensor(ctx_up_b, tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_up_b);
layer.ffn_gate_exps_b = create_tensor(ctx_gate_b, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_gate_b);
}
Expand Down Expand Up @@ -3155,11 +3167,11 @@ bool create_tensors_helper::merge_up_gate_exps(const LLM_TN & tn, int i, int bia
LLAMA_LOG_INFO("%s: merging up/gate in layer %d\n", __func__, i);

layer.ffn_up_gate_exps = ggml_new_tensor_3d(u_ctx, u_meta->type, u_meta->ne[0], u_meta->ne[1] + g_meta->ne[1], u_meta->ne[2]);
snprintf(layer.ffn_up_gate_exps->name, GGML_MAX_NAME, "blk.%d.ffn_up_gate_exps.weight", i);
layer.ffn_up_exps = ml.create_tensor_as_view(u_ctx, layer.ffn_up_gate_exps, u_name.c_str(),
{ u_meta->ne[0], u_meta->ne[1], u_meta->ne[2] }, 0);
snprintf(layer.ffn_up_gate_exps->name, GGML_MAX_NAME, "blk.%d.ffn_gate_up_exps.weight", i);
layer.ffn_gate_exps = ml.create_tensor_as_view(u_ctx, layer.ffn_up_gate_exps, g_name.c_str(),
{ g_meta->ne[0], g_meta->ne[1], g_meta->ne[2] }, ggml_nbytes(layer.ffn_up_exps) ); //u_meta->ne[1]*u_meta->nb[1] );
{ g_meta->ne[0], g_meta->ne[1], g_meta->ne[2] }, 0);
layer.ffn_up_exps = ml.create_tensor_as_view(u_ctx, layer.ffn_up_gate_exps, u_name.c_str(),
{ u_meta->ne[0], u_meta->ne[1], u_meta->ne[2] }, ggml_nbytes(layer.ffn_gate_exps));

if (!bias) return true;

Expand All @@ -3180,11 +3192,11 @@ bool create_tensors_helper::merge_up_gate_exps(const LLM_TN & tn, int i, int bia
GGML_ASSERT(g_meta->ne[1] == g_meta_b->ne[0]);

layer.ffn_up_gate_exps_b = ggml_new_tensor_2d(ctx_split, u_meta_b->type, u_meta_b->ne[0] + g_meta_b->ne[0], u_meta->ne[1]);
snprintf(layer.ffn_up_gate_exps_b->name, GGML_MAX_NAME, "blk.%d.ffn_up_gate_exps.bias", i);
layer.ffn_up_exps_b = ml.create_tensor_as_view(ctx_split, layer.ffn_up_gate_exps_b, u_name_b.c_str(),
{ u_meta_b->ne[0], u_meta_b->ne[1] }, 0);
snprintf(layer.ffn_up_gate_exps_b->name, GGML_MAX_NAME, "blk.%d.ffn_gate_up_exps.bias", i);
layer.ffn_gate_exps_b = ml.create_tensor_as_view(ctx_split, layer.ffn_up_gate_exps_b, g_name_b.c_str(),
{ g_meta_b->ne[0], g_meta_b->ne[1] }, ggml_nbytes(layer.ffn_up_exps_b) ); //u_meta->nb[1]);
{ g_meta_b->ne[0], g_meta_b->ne[1] }, 0);
layer.ffn_up_exps_b = ml.create_tensor_as_view(ctx_split, layer.ffn_up_gate_exps_b, u_name_b.c_str(),
{ u_meta_b->ne[0], u_meta_b->ne[1] }, ggml_nbytes(layer.ffn_gate_exps_b));

return true;
}
Expand All @@ -3198,10 +3210,21 @@ bool create_tensors_helper::create_std_ffn_exps(int64_t n_embd, const LLM_TN & t
auto & layer = model.layers[i];
auto ffn_ctx = ctx_for_layer_split(i);

bool merged = flags == 0 && ml.merge_up_gate_exps && merge_up_gate_exps(tn, i, 0);
if (!merged) {
layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
bool merged = false;
auto ug_name = tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i);
auto ug_meta = ml.get_tensor_meta(ug_name.c_str());
//printf("Checking for tensor %s: %s\n", ug_name.c_str(), ug_meta ? "found" : "not found");
if (ug_meta) {
if (model.split_mode == LLAMA_SPLIT_MODE_ATTN || model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
GGML_ABORT("Merged ffn_up_exps/ffn_gate_exps are not supported for split mode graph!");
}
layer.ffn_up_gate_exps = create_tensor(ffn_ctx, ug_name, { n_embd, 2*n_ff_exp, n_expert}, flags);
} else {
merged = flags == 0 && ml.merge_up_gate_exps && merge_up_gate_exps(tn, i, 0);
if (!merged) {
layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
}
}
layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);

Expand Down
Loading