ENH more stable gradient of CrossEntropy #6327

Merged (9 commits, Feb 22, 2024)
45 changes: 39 additions & 6 deletions

src/objective/xentropy_objective.hpp
@@ -75,21 +75,54 @@ class CrossEntropy: public ObjectiveFunction {
   }

   void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
+    // z = expit(score) = 1 / (1 + exp(-score))
+    // gradient = z - label = expit(score) - label
+    // Numerically more stable, see http://fa.bianp.net/blog/2019/evaluate_logistic/
+    //   if score < 0:
+    //     exp_tmp = exp(score)
+    //     return ((1 - label) * exp_tmp - label) / (1 + exp_tmp)
+    //   else:
+    //     exp_tmp = exp(-score)
+    //     return ((1 - label) - label * exp_tmp) / (1 + exp_tmp)
+    // Note that optimal speed would be achieved, at the cost of precision, by
+    //   return expit(score) - y_true
+    // i.e. no "if else" and our own inline implementation of expit.
+    // The case distinction score < 0 in the stable implementation does not
+    // provide significantly better precision, apart from protecting exp(..)
+    // from overflow. The branch (if else), however, can incur runtime costs
+    // of up to 30%. Instead, we help branch prediction by almost always ending
+    // up in the first if clause and making the second branch (else) a bit
+    // simpler. This has the exact same precision but is faster than the stable
+    // implementation. As branching criterion, we use the same cutoff as in
+    // log1pexp, see the link above. Note that the maximal score that gives
+    // gradient = -1 with label = 1 is -37.439198610162731 (based on mpmath),
+    // and scipy.special.logit(np.finfo(float).eps) ~ -36.04365.
     if (weights_ == nullptr) {
       // compute pointwise gradients and Hessians with implied unit weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>(z - label_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z));
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp));
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp);

Collaborator (suggested change):
-          hessians[i] = static_cast<score_t>(exp_tmp);
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));

Contributor Author:
This is not needed, as exp_tmp < 1e-16 is tiny and (1 + exp_tmp) is just 1 in double precision. Stated otherwise, the implemented formula is the first-order Taylor expansion in exp_tmp.

Collaborator:
I see. That makes sense.

Collaborator:
But maybe it would still be better to write the original calculation formula explicitly to avoid ambiguity?

Contributor Author (@lorentzenchr, Feb 20, 2024):
What do you mean by "ambiguity"? It would not avoid the branch, and the implemented version is a tiny bit more efficient.
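
A quick standalone check of this argument (a minimal sketch, not part of the PR): at the branch boundary score = -37, exp(score) is already below DBL_EPSILON, so 1.0 + exp_tmp rounds back to exactly 1.0 and the suggested division by (1 + exp_tmp)^2 cannot change the result.

#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  // Largest value exp(score) can take in the else branch (score <= -37).
  const double exp_tmp = std::exp(-37.0);  // ~8.53e-17
  std::printf("exp(-37) = %.17g, DBL_EPSILON = %.17g\n", exp_tmp, DBL_EPSILON);
  assert(exp_tmp < DBL_EPSILON);   // far below the rounding unit at 1.0
  assert(1.0 + exp_tmp == 1.0);    // the denominator collapses to exactly 1.0
  // Hence the suggested division is a no-op in double precision.
  assert(exp_tmp / ((1.0 + exp_tmp) * (1.0 + exp_tmp)) == exp_tmp);
  return 0;
}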

+        }
       }
     } else {
       // compute pointwise gradients and Hessians with given weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>((z - label_[i]) * weights_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z) * weights_[i]);
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)) * weights_[i]);
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>((exp_tmp - label_[i]) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp * weights_[i]);

lorentzenchr marked this conversation as resolved.

+        }
       }
     }
   }
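
For context, the effect of the new branch can be checked in isolation. The sketch below (an illustration, not part of the PR or of LightGBM; plain doubles stand in for score_t) compares the naive expit-based gradient and Hessian with the branched stable form at extreme scores. Once exp(-score) overflows to infinity, the naive Hessian z * (1 - z) collapses to exactly 0, while the stable form keeps the tiny but nonzero value exp(score); for moderate scores the two forms agree.

#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
  const double label = 1.0;
  for (double score : {-745.0, -710.0, -37.0, 0.0, 37.0}) {
    // Naive form: z = expit(score). exp(-score) overflows to +inf for
    // score below about -709.8, which makes z exactly 0.
    const double z = 1.0 / (1.0 + std::exp(-score));
    const double g_naive = z - label;
    const double h_naive = z * (1.0 - z);
    // Stable branched form from the PR: evaluate exp(-score) only when it
    // cannot overflow; otherwise use the first-order expansion in exp(score).
    double g_stable, h_stable;
    if (score > -37.0) {
      const double exp_tmp = std::exp(-score);
      g_stable = ((1.0 - label) - label * exp_tmp) / (1.0 + exp_tmp);
      h_stable = exp_tmp / ((1.0 + exp_tmp) * (1.0 + exp_tmp));
    } else {
      const double exp_tmp = std::exp(score);
      g_stable = exp_tmp - label;
      h_stable = exp_tmp;
    }
    std::printf("score=%7.1f  grad: %.9g vs %.9g  hess: %.9g vs %.9g\n",
                score, g_naive, g_stable, h_naive, h_stable);
  }
  return 0;
}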