Skip to content

Commit

Permalink
Addressed issues with an extremely imbalanced and small dataset by removing NaN values from the metrics. (This is a temporary fix.)
Browse files Browse the repository at this point in the history
  • Loading branch information
HyunjunA committed Mar 12, 2024
1 parent 8f63ff5 commit 6a37ab8
Showing 1 changed file with 18 additions and 34 deletions.
52 changes: 18 additions & 34 deletions machine/learn/skl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def get_column_names_from_ColumnTransformer(column_transformer, feature_names):
new_feature_names += feature_columns
return new_feature_names

# decision rule for cross validation 2, 3, 4, 5, 6, 7, 8, 9, 10
def decision_rule_cv_based_on_classes(each_class):
# decision rule for choosing number of folds based on the class distribution in the given dataset
def decision_rule_fold_cv_based_on_classes(each_class):
"""
Adjusts the number of cross-validation folds based on the class distribution.
Expand All @@ -188,23 +188,19 @@ def decision_rule_cv_based_on_classes(each_class):
Returns
-------
cv : int
Adjusted number of cross-validation folds.
The suitable number of cross-validation folds ensuring that each fold can include instances of each class.
"""
# Find the class with the minimum number of samples based on the class sample counts
min_samples = min(each_class.values())
# Find the minimum class count to ensure every fold can contain at least one instance of every class.
min_class_count = min(each_class.values())

# Calculate the number of classes
n_classes = len(each_class)
# The maximum number of folds is determined by the smallest class to ensure representation in each fold.
# However, we cannot have more folds than the minimum class count.
n_folds = min(10, min_class_count) # Starting with a default max of 10 folds

# Determine the appropriate number of cv folds based on the class with the minimum samples
if n_classes == 2:
# For binary classification, ensure at least one sample of each class is present in the folds, to the extent possible
n_split = min(max(2, min_samples), 10)
else:
# For multi-class, use more folds if possible to balance between classes
n_split = min(max(3, min_samples), 10)
# Ensure at least 2 folds for meaningful cross-validation.
n_folds = max(n_folds, 2)

return n_split
return n_folds

def generate_results(model, input_data,
tmpdir, _id, target_name='class',
Expand Down Expand Up @@ -418,29 +414,17 @@ def generate_results(model, input_data,
target, cv, return_times=True)
model.fit(features, target)

# # plot learning curve
# plot_learning_curve(tmpdir,_id, model,features,target,cv,return_times=True)
# StratifiedKFold
# stratified_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Initialize RepeatedStratifiedKFold


# n_splits = 2
# n_repeats = 2
# stratified_cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=42)

# print("stratified_cv", stratified_cv)
# computing cross-validated metrics

# Temporary fix to handle NaN values
stratified_cv = StratifiedKFold(n_splits=8)


cv_scores = cross_validate(
estimator=model,
X=features,
y=target,
scoring=scoring,
cv = stratified_cv,
# cv = stratified_cv,
cv = cv,
return_train_score=True,
return_estimator=True
)
Expand Down Expand Up @@ -784,9 +768,9 @@ def plot_confusion_matrix(
None
"""
pred_y = np.empty(y.shape)
# cv = StratifiedKFold(n_splits=10)
cv = StratifiedKFold(n_splits=10)
# Temporary fix to handle NaN values
cv = StratifiedKFold(n_splits=8)
# cv = StratifiedKFold(n_splits=8)
for cv_split, est in zip(cv.split(X, y), cv_scores['estimator']):
train, test = cv_split
pred_y[test] = est.predict(X[test])
Expand Down Expand Up @@ -1079,9 +1063,9 @@ def plot_roc_curve(tmpdir, _id, X, y, cv_scores, figure_export):
"""
from scipy import interp
from scipy.stats import sem, t
# cv = StratifiedKFold(n_splits=10)
cv = StratifiedKFold(n_splits=10)
# Temporary fix to handle NaN values
cv = StratifiedKFold(n_splits=8)
# cv = StratifiedKFold(n_splits=8)
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)
Expand Down

0 comments on commit 6a37ab8

Please sign in to comment.