From 3fdbdd9ac8790361b7f37b27faebe91cf1ad165f Mon Sep 17 00:00:00 2001 From: jiong-zhang Date: Wed, 6 Dec 2023 20:48:42 +0000 Subject: [PATCH] move platt-transform expcetion to python --- pecos/core/base.py | 23 ++++++++++++++++++----- pecos/core/libpecos.cpp | 4 ++-- pecos/core/utils/newton.hpp | 24 +++++++++++++----------- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/pecos/core/base.py b/pecos/core/base.py index 91109f9..36752d1 100644 --- a/pecos/core/base.py +++ b/pecos/core/base.py @@ -2049,12 +2049,12 @@ def link_calibrator_methods(self): """ corelib.fillprototype( self.clib_float32.c_fit_platt_transform_f32, - None, + c_uint32, [c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)], ) corelib.fillprototype( self.clib_float32.c_fit_platt_transform_f64, - None, + c_uint32, [c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)], ) @@ -2080,14 +2080,14 @@ def fit_platt_transform(self, logits, tgt_prob): AB = np.array([0, 0], dtype=np.float64) if tgt_prob.dtype == np.float32: - clib.clib_float32.c_fit_platt_transform_f32( + return_code = clib.clib_float32.c_fit_platt_transform_f32( len(logits), logits.ctypes.data_as(POINTER(c_float)), tgt_prob.ctypes.data_as(POINTER(c_float)), AB.ctypes.data_as(POINTER(c_double)), ) elif tgt_prob.dtype == np.float64: - clib.clib_float32.c_fit_platt_transform_f64( + return_code = clib.clib_float32.c_fit_platt_transform_f64( len(logits), logits.ctypes.data_as(POINTER(c_double)), tgt_prob.ctypes.data_as(POINTER(c_double)), @@ -2096,7 +2096,20 @@ def fit_platt_transform(self, logits, tgt_prob): else: raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}") - return AB[0], AB[1] + PLATT_RETURN_CODE = { + "SUCCESS": 0, + "LINE_SEARCH_FAIL": 1, + "MAX_ITER_REACHED": 2, + } + + if return_code == PLATT_RETURN_CODE["SUCCESS"]: + return AB[0], AB[1] + elif return_code == PLATT_RETURN_CODE["LINE_SEARCH_FAIL"]: + raise RuntimeError("fit_platt_transform: Line search fails") + elif return_code == PLATT_RETURN_CODE["MAX_ITER_REACHED"]: + raise RuntimeError("fit_platt_transform: Reaching maximal iterations") + else: + raise ValueError(f"Unknown return code {return_code}") clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos") diff --git a/pecos/core/libpecos.cpp b/pecos/core/libpecos.cpp index 599c4e5..3d62f97 100644 --- a/pecos/core/libpecos.cpp +++ b/pecos/core/libpecos.cpp @@ -752,13 +752,13 @@ extern "C" { // ==== C Interface of Score Calibrator ==== #define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \ - void c_fit_platt_transform ## SUFFIX( \ + uint32_t c_fit_platt_transform ## SUFFIX( \ size_t num_samples, \ const VAL_TYPE* logits, \ const VAL_TYPE* tgt_probs, \ double* AB \ ) { \ - pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \ + return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \ } C_FIT_PLATT_TRANSFORM(_f32, float32_t) C_FIT_PLATT_TRANSFORM(_f64, float64_t) diff --git a/pecos/core/utils/newton.hpp b/pecos/core/utils/newton.hpp index 0dcd68d..f6c088e 100644 --- a/pecos/core/utils/newton.hpp +++ b/pecos/core/utils/newton.hpp @@ -279,7 +279,14 @@ namespace pecos { // https://github.com/cjlin1/libsvm/blob/master/svm.cpp template - static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) { + uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) { + // define the return code + enum { + SUCCESS=0, + LINE_SEARCH_FAIL=1, + MAX_ITER_REACHED=2, + }; + // hyper parameters int max_iter = 100; // Maximal number of iterations double min_step = 1e-10; // Minimal step taken in line search @@ -292,14 +299,6 @@ namespace pecos { A = 0.0; B = 1.0; double fval = 0.0; - // check for out of bound in tgt_probs - for (size_t i = 0; i < num_samples; i++) { - if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) { - throw std::runtime_error("fit_platt_transform: target probability out of bound\n"); - } - } - - for (size_t i = 0; i < num_samples; i++) { double fApB = logits[i] * A + B; if (fApB >= 0) { @@ -376,13 +375,16 @@ namespace pecos { } if (stepsize < min_step) { - throw std::runtime_error("fit_platt_transform: Line search fails\n"); + printf("WARNING: fit_platt_transform: Line search fails\n"); + return LINE_SEARCH_FAIL; } } if (iter >= max_iter) { - throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n"); + printf("WARNING: fit_platt_transform: Reaching maximal iterations\n"); + return MAX_ITER_REACHED; } + return SUCCESS; } } // namespace pecos #endif