8 changes: 8 additions & 0 deletions common/common.cpp
@@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.mla_attn = std::stoi(argv[i]);
return true;
}
if (arg == "-amb" || arg == "--attention-max-batch") {
CHECK_ARG
params.attn_max_batch = std::stoi(argv[i]);
return true;
}
if (arg == "-fmoe" || arg == "--fused-moe") {
params.fused_moe_up_gate = true;
return true;
@@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
@@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
@@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

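For reference, the new -amb / --attention-max-batch argument lands in gpt_params.attn_max_batch and is forwarded to llama_context_params above. A minimal sketch of setting the field directly through the public API (not part of this PR; the model path is a placeholder, and the loading/creation calls assume the llama.cpp-style API this codebase exposes):

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn     = false; // per common.h, the cap only applies when flash attention is off
    cparams.mla_attn       = 2;     // MLA with just the K cache
    cparams.attn_max_batch = 64;    // the new field (defaults to 0)
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... evaluate batches as usual ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}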
3 changes: 2 additions & 1 deletion common/common.h
@@ -175,7 +175,8 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
35 changes: 32 additions & 3 deletions examples/llama-bench/llama-bench.cpp
@@ -233,6 +233,7 @@ struct cmd_params {
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<int> mla_attn;
std::vector<int> attn_max_batch;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = {
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* mla_attn */ {0},
/* attn_max_batch */ {0},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -301,6 +303,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -578,6 +581,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<int>(argv[i], split_delim);
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-amb" || arg == "--attn-max-batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<int>(argv[i], split_delim);
params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -690,6 +700,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -727,6 +738,7 @@ struct cmd_params_instance {
bool no_kv_offload;
bool flash_attn;
int mla_attn;
int attn_max_batch;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -773,6 +785,7 @@ struct cmd_params_instance {
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.mla_attn = mla_attn;
cparams.attn_max_batch = attn_max_batch;
cparams.fused_moe_up_gate = fmoe;
cparams.embeddings = embeddings;

@@ -799,6 +812,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & mla : params.mla_attn)
for (const auto & amb : params.attn_max_batch)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -821,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -852,6 +867,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -883,6 +899,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -914,6 +931,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -956,6 +974,7 @@ struct test {
bool no_kv_offload;
bool flash_attn;
int mla_attn;
int attn_max_batch;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -987,6 +1006,7 @@ struct test {
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
mla_attn = inst.mla_attn;
attn_max_batch = inst.attn_max_batch;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -1081,7 +1101,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1097,7 +1117,7 @@ struct test {
field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" ||
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
@@ -1138,7 +1158,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1305,6 +1325,9 @@ struct markdown_printer : public printer {
if (field == "mla_attn") {
return 3;
}
if (field == "attn_max_batch") {
return 5;
}
if (field == "use_mmap") {
return 4;
}
@@ -1345,6 +1368,9 @@ struct markdown_printer : public printer {
if (field == "mla_attn") {
return "mla";
}
if (field == "attn_max_batch") {
return "amb";
}
if (field == "use_mmap") {
return "mmap";
}
@@ -1403,6 +1429,9 @@ struct markdown_printer : public printer {
if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
fields.emplace_back("mla_attn");
}
if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch) {
fields.emplace_back("attn_max_batch");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
30 changes: 23 additions & 7 deletions ggml/src/ggml-cuda/concat.cu
@@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

float * dst_d = (float *)dst->data;

if (dim != 3) {
if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
} else {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
@@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
} else {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);

CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
}

//if (dim != 3) {
// for (int i3 = 0; i3 < dst->ne[3]; i3++) {
// concat_f32_cuda(
// src0_d + i3 * (src0->nb[3] / 4),
// src1_d + i3 * (src1->nb[3] / 4),
// dst_d + i3 * ( dst->nb[3] / 4),
// src0->ne[0], src0->ne[1], src0->ne[2],
// dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
// }
//} else {
// const size_t size0 = ggml_nbytes(src0);
// const size_t size1 = ggml_nbytes(src1);

// CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
// CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
//}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
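The rewritten branch above replaces the per-slice concat kernel with two device-to-device copies whenever the concatenation runs along the outermost non-trivial dimension (dim 3, or dim 2 with ne[3] == 1, or dim 1 with ne[2]*ne[3] == 1): in that layout the destination is simply src0's bytes followed by src1's bytes. A host-side sketch of the same reasoning for contiguous row-major data (illustrative only, not part of the PR):

#include <cassert>
#include <cstring>
#include <vector>

int main() {
    // ne0 = 4 is the fastest-varying dimension; concatenate along dim 1 with ne2 = ne3 = 1
    std::vector<float> src0(4 * 2, 1.0f); // shape [4, 2, 1, 1]
    std::vector<float> src1(4 * 3, 2.0f); // shape [4, 3, 1, 1]
    std::vector<float> dst (4 * 5);       // shape [4, 5, 1, 1]

    // equivalent of the two cudaMemcpyAsync calls: dst is src0's data followed by src1's
    std::memcpy(dst.data(),               src0.data(), src0.size() * sizeof(float));
    std::memcpy(dst.data() + src0.size(), src1.data(), src1.size() * sizeof(float));

    assert(dst[0] == 1.0f && dst[4 * 2 - 1] == 1.0f && dst[4 * 2] == 2.0f && dst.back() == 2.0f);
    return 0;
}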
20 changes: 20 additions & 0 deletions ggml/src/ggml.c
@@ -12627,6 +12627,26 @@ static void ggml_compute_forward_concat_f32(

GGML_ASSERT(dim >= 0 && dim < 4);

if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) &&
(dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
// simply copy the data
const int64_t size_src_0 = ggml_nbytes(src0);
const int64_t size_src_1 = ggml_nbytes(src1);
const int64_t block_size = 4096;
const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size;
for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) {
const int64_t start = i_block*block_size;
const int64_t end   = MIN(start + block_size, size_src_0 + size_src_1);
if (start < size_src_0) {
memcpy((char *)dst->data + start, (char *)src0->data + start, MIN(end, size_src_0) - start);
if (end > size_src_0) {
// the block straddles the src0/src1 boundary: also copy the first bytes of src1
memcpy((char *)dst->data + size_src_0, (char *)src1->data, end - size_src_0);
}
} else {
memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, end - start);
}
}
return;
}

int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];

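The CPU path added above does the same contiguous copy, but splits the combined byte range into 4096-byte blocks so the work is shared among the nth threads, with thread ith taking every nth block. A standalone sketch of that partitioning (illustrative only; the serial loop over ith stands in for ggml's worker threads):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>

// copy the concatenation of src0 and src1 into dst, handling only every nth 4096-byte block
static void copy_blocks(char * dst, const char * src0, int64_t size0,
                        const char * src1, int64_t size1, int ith, int nth) {
    const int64_t block_size = 4096;
    const int64_t num_blocks = (size0 + size1 + block_size - 1) / block_size;
    for (int64_t i = ith; i < num_blocks; i += nth) {
        const int64_t start = i * block_size;
        const int64_t end   = std::min(start + block_size, size0 + size1);
        if (start < size0) {
            std::memcpy(dst + start, src0 + start, std::min(end, size0) - start);
            if (end > size0) { // the block straddles the src0/src1 boundary
                std::memcpy(dst + size0, src1, end - size0);
            }
        } else {
            std::memcpy(dst + start, src1 + (start - size0), end - start);
        }
    }
}

int main() {
    char src0[6000], src1[6000], dst[12000];
    std::memset(src0, 'a', sizeof(src0));
    std::memset(src1, 'b', sizeof(src1));
    for (int ith = 0; ith < 4; ith++) { // emulate nth = 4 threads
        copy_blocks(dst, src0, sizeof(src0), src1, sizeof(src1), ith, 4);
    }
    std::printf("%c %c\n", dst[5999], dst[6000]); // prints: a b
    return 0;
}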
1 change: 1 addition & 0 deletions include/llama.h
@@ -384,6 +384,7 @@ extern "C" {
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]

// Abort callback