4 changes: 3 additions & 1 deletion .gitignore
@@ -6,6 +6,7 @@ build-sanitize-thread/
build-cov/
build-ci-debug/
build-ci-release/
build-cublas/
out/
tmp/
models/
@@ -15,6 +16,7 @@ compile_commands.json
CMakeSettings.json
.vs/
.vscode/
.clangd

.exrc
.cache
@@ -32,4 +34,4 @@ zig-cache/

*.sw?

__pycache__/
__pycache__/
10 changes: 10 additions & 0 deletions examples/gpt-2/CMakeLists.txt
@@ -11,3 +11,13 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
set(TEST_TARGET gpt-2-quantize)
add_executable(${TEST_TARGET} quantize.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

#
# For GPU offloading

if (GGML_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS)
endif()
if (GGML_CLBLAST)
add_compile_definitions(GGML_USE_CLBLAST)
endif()
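
These compile definitions let the example sources pick a GPU backend at build time. A minimal sketch of how they are typically consumed (the CUDA branch is what main.cpp below actually does; the CLBlast branch is illustrative only and assumes ggml's ggml-opencl.h header, since this change does not wire up an OpenCL backend):

#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"    // CUDA backend, provides ggml_backend_cuda_init()
#endif
#ifdef GGML_USE_CLBLAST
#include "ggml-opencl.h"  // illustrative only: no OpenCL backend is initialized in this change
#endif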
175 changes: 127 additions & 48 deletions examples/gpt-2/main.cpp
@@ -1,5 +1,10 @@
#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"
#include "ggml/ggml-backend.h"

#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif

#include "common.h"
#include "common-ggml.h"
@@ -70,11 +75,17 @@ struct gpt2_model {

//
struct ggml_context * ctx;

ggml_backend_t backend = NULL;

ggml_backend_buffer_t buffer_w;
ggml_backend_buffer_t buffer_kv;

std::map<std::string, struct ggml_tensor *> tensors;
};

// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());

auto fin = std::ifstream(fname, std::ios::binary);
@@ -155,7 +166,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &

auto & ctx = model.ctx;

size_t ctx_size = 0;
size_t buffer_size = 0;

{
const auto & hparams = model.hparams;
@@ -165,46 +176,44 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b

ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
buffer_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b

ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b

ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
buffer_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
buffer_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
buffer_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
buffer_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
buffer_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b

ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
buffer_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b

ctx_size += (6 + 12*n_layer)*512; // object overhead
buffer_size += (6 + 12*n_layer)*128; // alignment overhead

printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0));
}

// create the ggml context
{
size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_size =*/ ggml_tensor_overhead() * n_tensors,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
/*.no_alloc =*/ true,
};

model.ctx = ggml_init(params);
@@ -214,6 +223,31 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}
}
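
With no_alloc set to true, the context above only stores tensor metadata and can be sized from the per-tensor overhead; the actual tensor data is placed into backend buffers later. A condensed sketch of the new sizing logic, reusing the names from the change above:

    size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;      // upper bound: global tensors + 12 weights per layer
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors, // room for the tensor structs only
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,                               // tensor data lives in backend buffers
    };
    model.ctx = ggml_init(params);                            // tensors created here carry no data yet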

// initialize the backend
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > 0) {
fprintf(stderr, "%s: using CUDA backend\n", __func__);
model.backend = ggml_backend_cuda_init();
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#endif

if (!model.backend) {
// fallback to CPU backend
fprintf(stderr, "%s: using CPU backend\n", __func__);
model.backend = ggml_backend_cpu_init();
}

if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_cpu_init() failed\n", __func__);
return false;
}

// allocate weights buffer
model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);

// prepare memory for the weights
{
const auto & hparams = model.hparams;
@@ -299,14 +333,34 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);

// create a backend buffer (can be in host or device memory)
model.buffer_kv = ggml_backend_alloc_buffer(model.backend, memory_size + 256);

// allocate the tensors into the backend buffer
// TODO: better API for this
{
ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_kv);

// this updates the pointers in the tensors to point to the correct location in the buffer
// this is necessary since the ggml_context is .no_alloc == true
ggml_allocr_alloc(alloc, model.memory_k);
ggml_allocr_alloc(alloc, model.memory_v);

ggml_allocr_free(alloc);
}
}
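
The block above is the general recipe for placing pre-created tensors into backend memory: allocate a backend buffer, wrap it in an allocator, and allocate each tensor from it. A hedged sketch of the same pattern as a reusable helper (the helper name and the 256-byte alignment padding are illustrative, mirroring the values used above):

static ggml_backend_buffer_t alloc_tensors_in_backend(ggml_backend_t backend,
                                                      const std::vector<struct ggml_tensor *> & tensors) {
    size_t size = 0;
    for (auto * t : tensors) {
        size += ggml_nbytes(t) + 256;   // extra room for alignment padding
    }

    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, size);

    ggml_allocr * alloc = ggml_allocr_new_from_buffer(buf);
    for (auto * t : tensors) {
        ggml_allocr_alloc(alloc, t);    // points the tensor at its slot inside buf
    }
    ggml_allocr_free(alloc);

    return buf;                         // free with ggml_backend_buffer_free when the tensors are no longer needed
}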

// load weights
{
ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_w);

size_t total_size = 0;

bool has_lm_head = false;

std::vector<char> read_buf;

while (true) {
int32_t n_dims;
int32_t length;
@@ -336,6 +390,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}

auto tensor = model.tensors[name];
ggml_set_name(tensor, name.c_str());
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
return false;
@@ -360,11 +415,23 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
return false;
}

fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
ggml_allocr_alloc(alloc, tensor);

if (ggml_backend_is_cpu(model.backend)) {
// for the CPU backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor));
fin.read(read_buf.data(), ggml_nbytes(tensor));
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
}

// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
//ggml_allocr_alloc(alloc, model.lm_head);
//ggml_backend_tensor_copy(tensor, model.lm_head);
model.lm_head = tensor;
}

if (name == "model/lm_head") {
@@ -374,6 +441,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
total_size += ggml_nbytes(tensor);
}

ggml_allocr_free(alloc);
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
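
The weight loading loop above picks its read path based on where the tensor data lives: the CPU backend exposes host pointers, so the file can be read straight into tensor->data, while other backends need a host staging buffer and an explicit upload. A condensed sketch of that pattern (the helper name is illustrative):

static void read_tensor_data(std::ifstream & fin, ggml_backend_t backend,
                             struct ggml_tensor * tensor, std::vector<char> & staging) {
    const size_t nbytes = ggml_nbytes(tensor);
    if (ggml_backend_is_cpu(backend)) {
        // host memory: read directly into the tensor
        fin.read(reinterpret_cast<char *>(tensor->data), nbytes);
    } else {
        // device memory: stage on the host, then copy into the backend buffer
        staging.resize(nbytes);
        fin.read(staging.data(), nbytes);
        ggml_backend_tensor_set(tensor, staging.data(), 0, nbytes);
    }
}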

@@ -416,21 +484,23 @@ struct ggml_cgraph * gpt2_graph(

// avoid writing to tensors if we are only measuring the memory usage
if (!ggml_allocr_is_measure(allocr)) {
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd));
}

struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(allocr, position);
if (!ggml_allocr_is_measure(allocr)) {
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
int32_t v = n_past + i;
ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));
}
}
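// equivalent alternative (illustrative, not part of this change): stage the positions
// on the host and upload them with a single call; the offset and size arguments of
// ggml_backend_tensor_set are in bytes
//
//   std::vector<int32_t> pos(N);
//   for (int i = 0; i < N; ++i) pos[i] = n_past + i;
//   ggml_backend_tensor_set(position, pos.data(), 0, N*sizeof(int32_t));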

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(allocr, KQ_scale);
if (!ggml_allocr_is_measure(allocr)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
float s = 1.0f/sqrtf(float(n_embd)/n_head);
ggml_backend_tensor_set(KQ_scale, &s, 0, sizeof(s));
}

// wte + wpe
@@ -453,7 +523,8 @@ struct ggml_cgraph * gpt2_graph(
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
//ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
model.layers[il].ln_1_b);
}

// attn
@@ -599,7 +670,8 @@ struct ggml_cgraph * gpt2_graph(
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
//ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
model.layers[il].ln_2_b);
}

// fully connected
Expand Down Expand Up @@ -654,7 +726,8 @@ struct ggml_cgraph * gpt2_graph(
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
//ggml_repeat(ctx0, model.ln_f_b, inpL));
model.ln_f_b);
}

// inpL = WTE * inpL
@@ -703,11 +776,10 @@ bool gpt2_eval(
ggml_allocr_alloc_graph(allocr, gf);

// run the computation
struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
static std::vector<uint8_t> work_buffer;
work_buffer.resize(plan.work_size);
plan.work_data = work_buffer.data();
ggml_graph_compute(gf, &plan);
if (ggml_backend_is_cpu(model.backend)) {
ggml_backend_cpu_set_n_threads(model.backend, n_threads);
}
ggml_backend_graph_compute(model.backend, gf);

//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
@@ -718,11 +790,11 @@ bool gpt2_eval(
struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];

//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
//ggml_backend_tensor_get(inpL, embd_w.data(), 0, sizeof(float)*n_vocab*N);

// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
ggml_backend_tensor_get(inpL, embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab);

return true;
}
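
With the backend API the computed tensors live in backend memory, and results are copied back to the host with ggml_backend_tensor_get, whose offset and size arguments are in bytes. A short sketch of how the last-token logits are extracted from the final node above (names taken from gpt2_eval):

    // inpL holds logits for all N tokens as N consecutive rows of n_vocab floats
    std::vector<float> logits(n_vocab);
    const size_t row_bytes = n_vocab*sizeof(float);
    ggml_backend_tensor_get(inpL, logits.data(), /*offset =*/ (N - 1)*row_bytes, /*size =*/ row_bytes);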
@@ -759,7 +831,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();

if (!gpt2_model_load(params.model, model, vocab)) {
if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
@@ -770,25 +842,27 @@ int main(int argc, char ** argv) {
}

// keep this buffer alive while evaluating the model
std::vector<uint8_t> compute_buffer;
ggml_backend_buffer_t buf_compute;

struct ggml_allocr * allocr = NULL;
// allocate the compute buffer
{
allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
// alignment required by the backend
size_t align = ggml_backend_get_alignment(model.backend);
allocr = ggml_allocr_new_measure(align);

// create the worst case graph for memory usage estimation
int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
int n_past = model.hparams.n_ctx - n_tokens;
struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));

// compute the required memory
size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);

// recreate the allocator with the required memory
ggml_allocr_free(allocr);
compute_buffer.resize(mem_size);
allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
allocr = ggml_allocr_new_from_buffer(buf_compute);

fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
}
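
The compute buffer sizing above follows a measure-then-allocate pattern: a measuring allocator walks a worst-case graph to determine the required size, and only then is a real buffer allocated from the backend. Condensed, using the names from the change above:

    // pass 1: measure
    allocr = ggml_allocr_new_measure(ggml_backend_get_alignment(model.backend));
    struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
    size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);    // nothing is written, only sized
    ggml_allocr_free(allocr);

    // pass 2: allocate for real, in backend memory (host or device)
    buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
    allocr = ggml_allocr_new_from_buffer(buf_compute);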
@@ -888,5 +962,10 @@ int main(int argc, char ** argv) {

ggml_free(model.ctx);

ggml_backend_buffer_free(model.buffer_w);
ggml_backend_buffer_free(model.buffer_kv);
ggml_backend_buffer_free(buf_compute);
ggml_backend_free(model.backend);

return 0;
}