Skip to content

Commit

Permalink
Refactor Cython variables
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Sep 7, 2024
1 parent 86ea424 commit a030857
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions quantile_forest/_quantile_forest_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ cdef class QuantileForest:
cdef intp_t n_trees, n_train, max_idx
cdef intp_t i, j, k, l
cdef bint use_mean
cdef double q
cdef list[char*] interpolations
cdef vector[double] leaf_samples
cdef vector[double] leaf_weights
Expand All @@ -615,7 +616,6 @@ cdef class QuantileForest:
cdef double train_wgt
cdef vector[int] n_leaf_samples
cdef int n_total_samples, n_total_trees
cdef double train_weight
cdef vector[vector[double]] leaf_preds
cdef vector[double] pred
cdef cnp.ndarray[float64_t, ndim=3] preds
Expand Down Expand Up @@ -702,15 +702,15 @@ cdef class QuantileForest:
for k in range(n_trees):
if X_indices is None or X_indices[i, k] is True:
idx = 0 if aggregate_leaves_first else k
train_weight = 1
train_wgt = 1
if weighted_leaves:
train_weight = 0
train_wgt = 0
if n_leaf_samples[k] > 0:
train_weight = 1 / <double>n_leaf_samples[k]
train_weight *= <double>n_total_samples
train_weight /= <double>n_total_trees
train_wgt = 1 / <double>n_leaf_samples[k]
train_wgt *= <double>n_total_samples
train_wgt /= <double>n_total_trees
train_weights[idx].insert(
train_weights[idx].end(), max_idx, train_weight
train_weights[idx].end(), max_idx, train_wgt
)

# For each list of training indices, calculate output.
Expand Down

0 comments on commit a030857

Please sign in to comment.