Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion imblearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
report_dict_label[headers[-1]] = support[i]
report += fmt % tuple(values)

report_dict[label] = report_dict_label
report_dict[target_names[i]] = report_dict_label

report += "\n"

Expand Down
42 changes: 37 additions & 5 deletions imblearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_iba_error_y_score_prob_error(score_loss):
aps(y_true, y_pred)


def test_classification_report_imbalanced_dict():
def test_classification_report_imbalanced_dict_with_target_names():
iris = datasets.load_iris()
y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)

Expand All @@ -471,12 +471,44 @@ def test_classification_report_imbalanced_dict():
output_dict=True,
)
outer_keys = set(report.keys())
inner_keys = set(report[0].keys())
inner_keys = set(report["setosa"].keys())

expected_outer_keys = {
0,
1,
2,
"setosa",
"versicolor",
"virginica",
"avg_pre",
"avg_rec",
"avg_spe",
"avg_f1",
"avg_geo",
"avg_iba",
"total_support",
}
expected_inner_keys = {"spe", "f1", "sup", "rec", "geo", "iba", "pre"}

assert outer_keys == expected_outer_keys
assert inner_keys == expected_inner_keys


def test_classification_report_imbalanced_dict_without_target_names():
iris = datasets.load_iris()
y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
print(iris.target_names)
report = classification_report_imbalanced(
y_true,
y_pred,
labels=np.arange(len(iris.target_names)),
output_dict=True,
)
print(report.keys())
outer_keys = set(report.keys())
inner_keys = set(report["0"].keys())

expected_outer_keys = {
"0",
"1",
"2",
"avg_pre",
"avg_rec",
"avg_spe",
Expand Down