Skip to content

Commit

Permalink
Update _quantile_forest_fast.pyx
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Sep 10, 2024
1 parent a030857 commit 2345dbd
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions quantile_forest/_quantile_forest_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ cpdef vector[double] calc_quantile(
if not issorted:
sort_cpp(inputs.begin(), inputs.end())

f = n_inputs + 1 - (2 * C)
f = <int>n_inputs + 1 - (2 * C)

out = vector[double](n_quantiles)

Expand All @@ -166,15 +166,15 @@ cpdef vector[double] calc_quantile(
idx = quantile * f + C - 1

# Check if the quantile is the first or last value.
if idx >= n_inputs - 1:
if idx >= <int>n_inputs - 1:
out[i] = inputs[n_inputs - 1]
continue
if idx <= 0:
out[i] = inputs[0]
continue

v_floor = inputs[<int>floor(idx)]
v_ceil = inputs[<int>ceil(idx)]
v_floor = inputs[<intp_t>floor(idx)]
v_ceil = inputs[<intp_t>ceil(idx)]

# Check if the quantile does not lie between two values.
if v_floor == v_ceil:
Expand All @@ -193,7 +193,7 @@ cpdef vector[double] calc_quantile(
out[i] = 0.5 * (v_floor + v_ceil)
elif s_interpolation == <char*>b"nearest":
if fabs(frac - 0.5) < 1e-16:
out[i] = inputs[<int>(round(idx / 2) * 2)]
out[i] = inputs[<intp_t>(round(idx / 2) * 2)]
else:
out[i] = v_floor if frac < 0.5 else v_ceil
elif s_interpolation == <char*>b"linear":
Expand Down Expand Up @@ -266,7 +266,7 @@ cpdef vector[double] calc_weighted_quantile(
cdef double f
cdef double quantile
cdef vector[double] cum_weights, sorted_quantile_indices
cdef int idx_floor, idx_ceil
cdef intp_t idx_floor, idx_ceil
cdef double p, p_floor, p_ceil
cdef double v_floor, v_ceil, frac
cdef vector[double] out
Expand Down Expand Up @@ -458,20 +458,20 @@ cpdef double calc_quantile_rank(
right = 0
for i in range(n_inputs):
if inputs[i] < score:
left = right = i + 1
left = right = <int>i + 1
elif inputs[i] == score:
right = i + 1
right = <int>i + 1
else:
break

if s_kind == <char*>b"rank":
out = (right + left + (1 if right > left else 0)) * 0.5 / n_inputs
out = (right + left + (1 if right > left else 0)) * 0.5 / <double>n_inputs
elif s_kind == <char*>b"weak":
out = right / (<double>n_inputs)
elif s_kind == <char*>b"strict":
out = left / (<double>n_inputs)
elif s_kind == <char*>b"mean":
out = (left + right) * 0.5 / n_inputs
out = (left + right) * 0.5 / <double>n_inputs

return out

Expand Down

0 comments on commit 2345dbd

Please sign in to comment.