Skip to content

Commit

Permalink
threadpool: move all pause/resume logic into ggml
Browse files Browse the repository at this point in the history
  • Loading branch information
max-krasnyansky committed Aug 27, 2024
1 parent 3bcc4de commit adcd24c
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 88 deletions.
2 changes: 1 addition & 1 deletion examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,7 @@ int main(int argc, char ** argv) {
exit(1);
}

llama_attach_threadpool(ctx, threadpool);
llama_attach_threadpool(ctx, threadpool, NULL);

// warmup run
if (t.n_prompt > 0) {
Expand Down
9 changes: 2 additions & 7 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,6 @@ int main(int argc, char ** argv) {
exit(1);
}

llama_attach_batch_threadpool(ctx, threadpool_batch);
if (ctx_guidance) {
llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
}

// Start the non-batch threadpool in the paused state
tpp.paused = true;
}
Expand All @@ -253,9 +248,9 @@ int main(int argc, char ** argv) {
exit(1);
}

llama_attach_threadpool(ctx, threadpool);
llama_attach_threadpool(ctx, threadpool, threadpool_batch);
if (ctx_guidance) {
llama_attach_threadpool(ctx_guidance, threadpool);
llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
}

const int n_ctx_train = llama_n_ctx_train(model);
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,12 @@ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_th
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;

if (ctx->threadpool) {
// already had threadpool, pause/suspend it before switching
ggml_pause_threadpool(ctx->threadpool);
}

ctx->threadpool = threadpool;
}

Expand Down
3 changes: 0 additions & 3 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -19218,9 +19218,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
state->pending = false;

ggml_graph_compute_thread(state);
if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
break;
}
}
}

Expand Down
11 changes: 2 additions & 9 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,16 +431,9 @@ extern "C" {
// Optional: an auto threadpool gets created in ggml if not passed explicitly
LLAMA_API void llama_attach_threadpool(
struct llama_context * ctx,
ggml_compute_threadpool_t threadpool);
LLAMA_API void llama_attach_batch_threadpool(
struct llama_context * ctx,
ggml_compute_threadpool_t threadpool);
ggml_compute_threadpool_t threadpool,
ggml_compute_threadpool_t threadpool_batch);
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
LLAMA_API void llama_detach_batch_threadpool(struct llama_context * ctx);
LLAMA_API void llama_detach_threadpools(struct llama_context * ctx);

// Pauses all attached threadpools
LLAMA_API void llama_pause_threadpools(struct llama_context * ctx);

// Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_backend_free(void);
Expand Down
77 changes: 9 additions & 68 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15523,39 +15523,6 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}

// Optionally swaps the batch and single-tok threadpools.
// Returns the number of threads, and if a valid threadpool exists, returns it too.
static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
llama_context & lctx,
int32_t n_tokens) {

const auto & cparams = lctx.cparams;
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;

ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool

// A batch threadpool without a non-batch threadpool isn't supported.
GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);

if (lctx.threadpool_batch && lctx.threadpool) {
// Switch between the 2 threadpools as needed
if (n_tokens > 1) {
ggml_pause_threadpool(lctx.threadpool);
threadpool = lctx.threadpool_batch;
n_threads = cparams.n_threads_batch;
} else {
ggml_pause_threadpool(lctx.threadpool_batch);
threadpool = lctx.threadpool;
n_threads = cparams.n_threads;
}
} else if (lctx.threadpool) {
threadpool = lctx.threadpool;
n_threads = cparams.n_threads;
}
return std::make_pair(n_threads, threadpool);
}


// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
Expand Down Expand Up @@ -15662,11 +15629,8 @@ static int llama_decode_internal(
lctx.n_outputs = n_outputs_new;
}

std::pair<int32_t, ggml_compute_threadpool_t> threads =
llama_swap_threadpools(lctx, n_tokens);

int n_threads = threads.first;
ggml_compute_threadpool_t threadpool = threads.second;
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

GGML_ASSERT(n_threads > 0);

Expand Down Expand Up @@ -15906,11 +15870,9 @@ static int llama_encode_internal(
lctx.inp_embd_enc = NULL;
lctx.n_outputs = n_tokens;

std::pair<int32_t, ggml_compute_threadpool_t> threads =
llama_swap_threadpools(lctx, n_tokens);
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

int n_threads = threads.first;
ggml_compute_threadpool_t threadpool = threads.second;
GGML_ASSERT(n_threads > 0);

ggml_backend_sched_reset(lctx.sched);
Expand Down Expand Up @@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {

void llama_attach_threadpool(
struct llama_context * ctx,
ggml_compute_threadpool_t threadpool) {
ctx->threadpool = threadpool;
}

void llama_attach_batch_threadpool(
struct llama_context * ctx,
ggml_compute_threadpool_t threadpool,
ggml_compute_threadpool_t threadpool_batch) {
ctx->threadpool_batch = threadpool_batch;
ctx->threadpool = threadpool;
ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
}

void llama_detach_threadpool(struct llama_context * ctx) {
ctx->threadpool = nullptr;
}

void llama_detach_batch_threadpool(struct llama_context * ctx) {
ctx->threadpool = nullptr;
}

void llama_detach_threadpools(struct llama_context * ctx) {
llama_detach_threadpool(ctx);
llama_detach_batch_threadpool(ctx);
}

void llama_pause_threadpools(struct llama_context * ctx) {
if (ctx->threadpool) {
ggml_pause_threadpool(ctx->threadpool);
}
if (ctx->threadpool_batch) {
ggml_pause_threadpool(ctx->threadpool_batch);
}
ctx->threadpool = nullptr;
ctx->threadpool_batch = nullptr;
}

void llama_backend_free(void) {
Expand Down

0 comments on commit adcd24c

Please sign in to comment.