Skip to content

Commit c884e71

Browse files
authored
move platt-transform expcetion to python (#270)
1 parent d259cbe commit c884e71

File tree

3 files changed

+33
-18
lines changed

3 files changed

+33
-18
lines changed

pecos/core/base.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -2049,12 +2049,12 @@ def link_calibrator_methods(self):
20492049
"""
20502050
corelib.fillprototype(
20512051
self.clib_float32.c_fit_platt_transform_f32,
2052-
None,
2052+
c_uint32,
20532053
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
20542054
)
20552055
corelib.fillprototype(
20562056
self.clib_float32.c_fit_platt_transform_f64,
2057-
None,
2057+
c_uint32,
20582058
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
20592059
)
20602060

@@ -2080,14 +2080,14 @@ def fit_platt_transform(self, logits, tgt_prob):
20802080
AB = np.array([0, 0], dtype=np.float64)
20812081

20822082
if tgt_prob.dtype == np.float32:
2083-
clib.clib_float32.c_fit_platt_transform_f32(
2083+
return_code = clib.clib_float32.c_fit_platt_transform_f32(
20842084
len(logits),
20852085
logits.ctypes.data_as(POINTER(c_float)),
20862086
tgt_prob.ctypes.data_as(POINTER(c_float)),
20872087
AB.ctypes.data_as(POINTER(c_double)),
20882088
)
20892089
elif tgt_prob.dtype == np.float64:
2090-
clib.clib_float32.c_fit_platt_transform_f64(
2090+
return_code = clib.clib_float32.c_fit_platt_transform_f64(
20912091
len(logits),
20922092
logits.ctypes.data_as(POINTER(c_double)),
20932093
tgt_prob.ctypes.data_as(POINTER(c_double)),
@@ -2096,7 +2096,20 @@ def fit_platt_transform(self, logits, tgt_prob):
20962096
else:
20972097
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")
20982098

2099-
return AB[0], AB[1]
2099+
PLATT_RETURN_CODE = {
2100+
"SUCCESS": 0,
2101+
"LINE_SEARCH_FAIL": 1,
2102+
"MAX_ITER_REACHED": 2,
2103+
}
2104+
2105+
if return_code == PLATT_RETURN_CODE["SUCCESS"]:
2106+
return AB[0], AB[1]
2107+
elif return_code == PLATT_RETURN_CODE["LINE_SEARCH_FAIL"]:
2108+
raise RuntimeError("fit_platt_transform: Line search fails")
2109+
elif return_code == PLATT_RETURN_CODE["MAX_ITER_REACHED"]:
2110+
raise RuntimeError("fit_platt_transform: Reaching maximal iterations")
2111+
else:
2112+
raise ValueError(f"Unknown return code {return_code}")
21002113

21012114

21022115
clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos")

pecos/core/libpecos.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -752,13 +752,13 @@ extern "C" {
752752
// ==== C Interface of Score Calibrator ====
753753

754754
#define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \
755-
void c_fit_platt_transform ## SUFFIX( \
755+
uint32_t c_fit_platt_transform ## SUFFIX( \
756756
size_t num_samples, \
757757
const VAL_TYPE* logits, \
758758
const VAL_TYPE* tgt_probs, \
759759
double* AB \
760760
) { \
761-
pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
761+
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
762762
}
763763
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
764764
C_FIT_PLATT_TRANSFORM(_f64, float64_t)

pecos/core/utils/newton.hpp

+13-11
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,14 @@ namespace pecos {
279279
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp
280280

281281
template <typename value_type>
282-
static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
282+
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
283+
// define the return code
284+
enum {
285+
SUCCESS=0,
286+
LINE_SEARCH_FAIL=1,
287+
MAX_ITER_REACHED=2,
288+
};
289+
283290
// hyper parameters
284291
int max_iter = 100; // Maximal number of iterations
285292
double min_step = 1e-10; // Minimal step taken in line search
@@ -292,14 +299,6 @@ namespace pecos {
292299
A = 0.0; B = 1.0;
293300
double fval = 0.0;
294301

295-
// check for out of bound in tgt_probs
296-
for (size_t i = 0; i < num_samples; i++) {
297-
if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) {
298-
throw std::runtime_error("fit_platt_transform: target probability out of bound\n");
299-
}
300-
}
301-
302-
303302
for (size_t i = 0; i < num_samples; i++) {
304303
double fApB = logits[i] * A + B;
305304
if (fApB >= 0) {
@@ -376,13 +375,16 @@ namespace pecos {
376375
}
377376

378377
if (stepsize < min_step) {
379-
throw std::runtime_error("fit_platt_transform: Line search fails\n");
378+
printf("WARNING: fit_platt_transform: Line search fails\n");
379+
return LINE_SEARCH_FAIL;
380380
}
381381
}
382382

383383
if (iter >= max_iter) {
384-
throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n");
384+
printf("WARNING: fit_platt_transform: Reaching maximal iterations\n");
385+
return MAX_ITER_REACHED;
385386
}
387+
return SUCCESS;
386388
}
387389
} // namespace pecos
388390
#endif

0 commit comments

Comments
 (0)