llama : add option to override model tensor buffers #11397

Draft: wants to merge 6 commits into master
Changes from all commits
40 changes: 40 additions & 0 deletions common/arg.cpp
@@ -1,5 +1,6 @@
#include "arg.h"

#include "common.h"
#include "log.h"
#include "sampling.h"

@@ -321,6 +322,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.kv_overrides.back().key[0] = 0;
}

if (!params.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}

if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
@@ -1490,6 +1495,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
exit(0);
}
));
add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) {
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
}

for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);

if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// FIXME: this leaks memory
params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
}
}
));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
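Usage note (not part of the diff): the option value is a comma-separated list of <regex>=<buffer type> pairs; the regex is later matched against tensor names during model load, and the buffer type must be one of the device buffer types enumerated above (the error path prints the available names). As a hypothetical invocation, assuming a CUDA device and a CPU backend that reports a buffer type named "CPU", running llama-cli -m model.gguf -ngl 99 -ot "\.ffn_.*_exps\.=CPU" would keep the matching expert tensors in host memory while the rest of the model follows the usual --gpu-layers placement; several overrides can be chained, e.g. -ot "blk\.2[0-9]\.=CUDA0,output\.weight=CPU".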
10 changes: 10 additions & 0 deletions common/common.cpp
@@ -1084,22 +1084,32 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}

if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}

mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;

if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
mparams.kv_overrides = params.kv_overrides.data();
}

if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
} else {
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}

return mparams;
}

1 change: 1 addition & 0 deletions common/common.h
@@ -266,6 +266,7 @@ struct common_params {
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
6 changes: 6 additions & 0 deletions ggml/src/ggml.c
@@ -1151,6 +1151,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
}

size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
if (tensor->ne[i] <= 0) {
return 0;
}
}

size_t nbytes;
const size_t blck_size = ggml_blck_size(tensor->type);
if (blck_size == 1) {
8 changes: 8 additions & 0 deletions include/llama.h
@@ -275,10 +275,18 @@ extern "C" {
};
};

struct llama_model_tensor_buft_override {
const char * pattern;
ggml_backend_buffer_type_t buft;
};

struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices;

// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;

int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs

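For reference, a minimal sketch of how an application could use the new field (not part of this PR; it assumes ggml_backend_cpu_buffer_type() from ggml-backend.h and llama_model_load_from_file() are available, and the pattern shown is hypothetical):

#include "llama.h"
#include "ggml-backend.h"

#include <vector>

static llama_model * load_with_overrides(const char * path) {
    // keep every tensor whose name contains ".ffn_" in host memory (hypothetical pattern)
    std::vector<llama_model_tensor_buft_override> overrides = {
        { ".ffn_",  ggml_backend_cpu_buffer_type() },
        { nullptr,  nullptr },  // the list must be terminated with a null pattern
    };

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers          = 99;               // offload everything else as usual
    mparams.tensor_buft_overrides = overrides.data(); // must stay alive during loading

    return llama_model_load_from_file(path, mparams);
}

The overrides array only needs to outlive the load call, since the patterns are matched once while the model tensors are created.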
5 changes: 4 additions & 1 deletion src/llama-model-loader.cpp
@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
std::vector<std::string> & splits,
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p) {
const llama_model_kv_override * param_overrides_p,
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
int trace = 0;
if (getenv("LLAMA_TRACE")) {
trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
}
}

tensor_buft_overrides = param_tensor_buft_overrides_p;

// Load the main GGUF
struct ggml_context * ctx = NULL;
struct gguf_init_params params = {
8 changes: 5 additions & 3 deletions src/llama-model-loader.h
@@ -77,8 +77,9 @@ struct llama_model_loader {

llama_mmaps mappings;

std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
const llama_model_tensor_buft_override * tensor_buft_overrides;

gguf_context_ptr meta;
std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p);
const llama_model_kv_override * param_overrides_p,
const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
23 changes: 21 additions & 2 deletions src/llama-model.cpp
@@ -11,6 +11,7 @@
#include <cstring>
#include <functional>
#include <map>
#include <regex>
#include <sstream>
#include <stdexcept>

@@ -1460,9 +1461,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
}

ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
ggml_backend_buffer_type_t buft = nullptr;

// check overrides
if (ml.tensor_buft_overrides) {
std::string tensor_name = tn.str();
for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
std::regex pattern(overrides->pattern);
if (std::regex_search(tensor_name, pattern)) {
LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
buft = overrides->buft;
break;
}
}
}

if (!buft) {
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
buft = select_weight_buft(hparams, t_meta, op, *buft_list);
if (!buft) {
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
}
}

// avoid using a host buffer when using mmap
@@ -3781,6 +3799,7 @@ const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.devices =*/ nullptr,
/*.tensor_buft_overrides =*/ nullptr,
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
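A small standalone illustration of the matching rule used above (not from the PR): because the pattern is applied with std::regex_search, it matches anywhere inside the tensor name rather than against the whole string. The tensor names below follow the usual llama.cpp naming scheme but are otherwise hypothetical.

#include <cstdio>
#include <regex>

int main() {
    // hypothetical tensor names in the usual llama.cpp naming scheme
    const char * names[] = { "blk.17.ffn_gate_exps.weight", "blk.17.attn_q.weight", "output.weight" };
    std::regex pattern("\\.ffn_.*_exps\\.");  // the same style of pattern a -ot override would use

    for (const char * name : names) {
        // regex_search performs a substring match, as in llama_model::load_tensors
        std::printf("%-30s -> %s\n", name, std::regex_search(name, pattern) ? "override" : "default buffer type");
    }
    return 0;
}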
2 changes: 1 addition & 1 deletion src/llama-quant.cpp
@@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}

std::vector<std::string> splits = {};
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
ml.init_mappings(false); // no prefetching

llama_model model(llama_model_default_params());
2 changes: 1 addition & 1 deletion src/llama.cpp
@@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
model.t_start_us = tm.t_start_us;

try {
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);

ml.print_info();
