Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,7 @@ extern "C" {
GGML_OPT_NO_CONTEXT,
GGML_OPT_INVALID_WOLFE,
GGML_OPT_FAIL,
GGML_OPT_CANCEL,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xaedes I've added the GGML_OPT_CANCEL return code and simplified the cancellation logic during optimization

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes a lot of sense.


GGML_LINESEARCH_FAIL = -128,
GGML_LINESEARCH_MINIMUM_STEP,
Expand Down
29 changes: 7 additions & 22 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -19762,7 +19762,7 @@ static enum ggml_opt_result ggml_opt_adam(
if (callback) {
callback(callback_data, accum_step, &sched, &cancel);
if (cancel) {
break;
return GGML_OPT_CANCEL;
}
}
// ggml_graph_reset (gf);
Expand All @@ -19771,9 +19771,6 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_opt_acc_grad(np, ps, g, accum_norm);
fx += ggml_get_f32_1d(f, 0);
}
if (cancel) {
return GGML_OPT_DID_NOT_CONVERGE;
}
fx *= accum_norm;

opt->adam.fx_prev = fx;
Expand All @@ -19799,9 +19796,6 @@ static enum ggml_opt_result ggml_opt_adam(

// run the optimizer
for (int t = 0; t < params.adam.n_iter; ++t) {
if (cancel) {
break;
}
opt->iter = iter0 + t + 1;
GGML_PRINT_DEBUG ("=== iter %d ===\n", t);

Expand Down Expand Up @@ -19859,7 +19853,7 @@ static enum ggml_opt_result ggml_opt_adam(
if (callback) {
callback(callback_data, accum_step, &sched, &cancel);
if (cancel) {
break;
return GGML_OPT_CANCEL;;
}
}
// ggml_graph_reset (gf);
Expand All @@ -19868,9 +19862,6 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_opt_acc_grad(np, ps, g, accum_norm);
fx += ggml_get_f32_1d(f, 0);
}
if (cancel) {
break;
}
fx *= accum_norm;

opt->loss_after = fx;
Expand Down Expand Up @@ -19989,7 +19980,7 @@ static enum ggml_opt_result linesearch_backtracking(
finit = *fx;
dgtest = params->lbfgs.ftol*dginit;

while (!*cancel) {
while (true) {
ggml_vec_cpy_f32(nx, x, xp);
ggml_vec_mad_f32(nx, x, d, *step);

Expand All @@ -20005,7 +19996,7 @@ static enum ggml_opt_result linesearch_backtracking(
float sched = 0;
callback(callback_data, accum_step, &sched, cancel);
if (*cancel) {
break;
return GGML_OPT_CANCEL;
}
}
// ggml_graph_reset (gf);
Expand All @@ -20014,9 +20005,6 @@ static enum ggml_opt_result linesearch_backtracking(
ggml_opt_acc_grad(np, ps, g, accum_norm);
*fx += ggml_get_f32_1d(f, 0);
}
if (*cancel) {
break;
}
*fx *= accum_norm;

}
Expand Down Expand Up @@ -20149,7 +20137,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
float sched = 0;
callback(callback_data, accum_step, &sched, &cancel);
if (cancel) {
break;
return GGML_OPT_CANCEL;
}
}
// ggml_graph_reset (gf);
Expand All @@ -20158,9 +20146,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_opt_acc_grad(np, ps, g, accum_norm);
fx += ggml_get_f32_1d(f, 0);
}
if (cancel) {
return GGML_OPT_DID_NOT_CONVERGE;
}
fx *= accum_norm;

opt->loss_before = fx;
Expand Down Expand Up @@ -20220,8 +20205,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_vec_cpy_f32(nx, gp, g);

ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here instead of passing &cancel, we should check the return code if it matches GGML_OPT_CANCEL

if (!cancel) {
break;
if (cancel) {
return GGML_OPT_CANCEL;
}

if (ls < 0) {
Expand Down