Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move platt-transform expcetion to python #270

Merged
merged 1 commit into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2049,12 +2049,12 @@ def link_calibrator_methods(self):
"""
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
None,
c_uint32,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
None,
c_uint32,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
)

Expand All @@ -2080,14 +2080,14 @@ def fit_platt_transform(self, logits, tgt_prob):
AB = np.array([0, 0], dtype=np.float64)

if tgt_prob.dtype == np.float32:
clib.clib_float32.c_fit_platt_transform_f32(
return_code = clib.clib_float32.c_fit_platt_transform_f32(
len(logits),
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
)
elif tgt_prob.dtype == np.float64:
clib.clib_float32.c_fit_platt_transform_f64(
return_code = clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
Expand All @@ -2096,7 +2096,20 @@ def fit_platt_transform(self, logits, tgt_prob):
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")

return AB[0], AB[1]
PLATT_RETURN_CODE = {
"SUCCESS": 0,
"LINE_SEARCH_FAIL": 1,
"MAX_ITER_REACHED": 2,
}

if return_code == PLATT_RETURN_CODE["SUCCESS"]:
return AB[0], AB[1]
elif return_code == PLATT_RETURN_CODE["LINE_SEARCH_FAIL"]:
raise RuntimeError("fit_platt_transform: Line search fails")
elif return_code == PLATT_RETURN_CODE["MAX_ITER_REACHED"]:
raise RuntimeError("fit_platt_transform: Reaching maximal iterations")
else:
raise ValueError(f"Unknown return code {return_code}")


clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos")
4 changes: 2 additions & 2 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,13 +752,13 @@ extern "C" {
// ==== C Interface of Score Calibrator ====

#define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \
void c_fit_platt_transform ## SUFFIX( \
uint32_t c_fit_platt_transform ## SUFFIX( \
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
) { \
pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
Expand Down
24 changes: 13 additions & 11 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,14 @@ namespace pecos {
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
// define the return code
enum {
SUCCESS=0,
LINE_SEARCH_FAIL=1,
MAX_ITER_REACHED=2,
};

// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
Expand All @@ -292,14 +299,6 @@ namespace pecos {
A = 0.0; B = 1.0;
double fval = 0.0;

// check for out of bound in tgt_probs
for (size_t i = 0; i < num_samples; i++) {
if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) {
throw std::runtime_error("fit_platt_transform: target probability out of bound\n");
OctoberChang marked this conversation as resolved.
Show resolved Hide resolved
}
}


for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
if (fApB >= 0) {
Expand Down Expand Up @@ -376,13 +375,16 @@ namespace pecos {
}

if (stepsize < min_step) {
throw std::runtime_error("fit_platt_transform: Line search fails\n");
printf("WARNING: fit_platt_transform: Line search fails\n");
return LINE_SEARCH_FAIL;
}
}

if (iter >= max_iter) {
throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n");
printf("WARNING: fit_platt_transform: Reaching maximal iterations\n");
return MAX_ITER_REACHED;
}
return SUCCESS;
}
} // namespace pecos
#endif
Loading