Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add C implementation to fit Platt scaling #266

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def __init__(self, dirname, soname, forced_rebuild=False):
self.link_ann_hnsw_methods()
self.link_mmap_hashmap_methods()
self.link_mmap_valstore_methods()
self.link_calibrator_methods()

def link_xlinear_methods(self):
"""
Expand Down Expand Up @@ -1939,5 +1940,60 @@ def mmap_valstore_init(self, store_type):
raise NotImplementedError(f"store_type={store_type} is not implemented.")
return self.mmap_valstore_fn_dict[store_type]

def link_calibrator_methods(self):
"""
Specify C-lib's score calibration methods arguments and return types.
"""
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
None,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
None,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
)

def fit_platt_transform(self, logits, tgt_prob):
"""Python to C/C++ interface for platt transfrom fit.

Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf

Args:
logits (ndarray): 1-d array of logit with length N.
tgt_prob (ndarray): 1-d array of target probability scores within [0, 1] with length N.
Returns:
A, B: coefficients for Platt's scale.
"""
assert isinstance(logits, np.ndarray)
assert isinstance(tgt_prob, np.ndarray)
assert len(logits) == len(tgt_prob)
assert logits.dtype == tgt_prob.dtype

if tgt_prob.min() < 0 or tgt_prob.max() > 1.0:
raise ValueError("Target probability out of bound!")

AB = np.array([0, 0], dtype=np.float64)

if tgt_prob.dtype == np.float32:
clib.clib_float32.c_fit_platt_transform_f32(
len(logits),
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
)
elif tgt_prob.dtype == np.float64:
clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
AB.ctypes.data_as(POINTER(c_double)),
)
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")

return AB[0], AB[1]


clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos")
14 changes: 14 additions & 0 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,4 +651,18 @@ extern "C" {
static_cast<mmap_valstore_bytes *>(map_ptr)->batch_get(
n_sub_row, n_sub_col, sub_rows, sub_cols, trunc_val_len, ret, ret_lens, threads);
}

// ==== C Interface of Score Calibrator ====

#define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \
void c_fit_platt_transform ## SUFFIX( \
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
) { \
pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
}
112 changes: 112 additions & 0 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,5 +272,117 @@ namespace pecos {
return cg_iter;
};
};


// Platt scale with given target curve.
// Reference Implementation:
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
double sigma = 1e-12; // For numerically strict PD of Hessian
double eps = 1e-6;

int iter;

// Initial Point and Initial Fun Value
A = 0.0; B = 1.0;
double fval = 0.0;

// check for out of bound in tgt_probs
for (size_t i = 0; i < num_samples; i++) {
if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) {
throw std::runtime_error("fit_platt_transform: target probability out of bound\n");
}
}


for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
if (fApB >= 0) {
fval += tgt_probs[i] * fApB + log(1 + exp(-fApB));
} else {
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
for (iter = 0; iter < max_iter; iter++) {
// Update Gradient and Hessian (use H' = H + sigma I)
double h11 = sigma;
double h22 = sigma; // numerically ensures strict PD
double h21 = 0.0;
double g1 = 0.0;
double g2 = 0.0;

for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
double p = 0, q = 0;
if (fApB >= 0) {
p = exp(-fApB) / (1.0 + exp(-fApB));
q = 1.0 / (1.0 + exp(-fApB));
} else {
p = 1.0 / (1.0 + exp(fApB));
q = exp(fApB) / (1.0 + exp(fApB));
}
double d1 = tgt_probs[i] - p;
double d2 = p * q;

h11 += d2 * logits[i] * logits[i];
h22 += d2;
h21 += logits[i] * d2;
g1 += logits[i] * d1;
g2 += d1;
}

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps)
break;

// Finding Newton direction: -inv(H') * g
double det = h11 * h22 - h21 * h21;
double dA = -(h22 * g1 - h21 * g2) / det;
double dB = -(-h21 * g1 + h11 * g2) / det;
double gd = g1 * dA + g2 * dB;

// Line Search
double stepsize = 1.0;

while (stepsize >= min_step) {
double newA = A + stepsize * dA;
double newB = B + stepsize * dB;

// New function value
double newf = 0.0;
for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * newA + newB;
if (fApB >= 0) {
newf += tgt_probs[i] * fApB + log(1 + exp(-fApB));
} else {
newf += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
// Check sufficient decrease
if (newf < fval + 0.0001 * stepsize * gd)
{
A = newA;
B = newB;
fval = newf;
break;
} else {
stepsize = stepsize / 2.0;
}
}

if (stepsize < min_step) {
throw std::runtime_error("fit_platt_transform: Line search fails\n");
}
}

if (iter >= max_iter) {
throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n");
}
}
} // namespace pecos
#endif
20 changes: 20 additions & 0 deletions test/pecos/core/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,23 @@ def test_sparse_inner_products():
assert true_vals == approx(
pred_vals, abs=1e-9
), f"true_vals != pred_vals, where X/Y are drm/dcm"


def test_platt_scale():
import numpy as np
from pecos.core import clib

A = 0.25
B = 3.14

orig = np.arange(-15, 15, 1, dtype=np.float32)
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float32)
At, Bt = clib.fit_platt_transform(orig, tgt)
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"

orig = np.arange(-15, 15, 1, dtype=np.float64)
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float64)
At, Bt = clib.fit_platt_transform(orig, tgt)
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"
Loading