From 0cf3f2f3c4ea7b4cf498e732295defdb712e4ad5 Mon Sep 17 00:00:00 2001 From: Paul Koch Date: Wed, 28 Feb 2024 23:00:22 -0800 Subject: [PATCH] fix bug in convert_categorical_to_continuous that did not handle scenarios where there is only one continuous section --- .../interpret/glassbox/_ebm/_utils.py | 4 ++-- .../tests/glassbox/ebm/test_ebm_utils.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python/interpret-core/interpret/glassbox/_ebm/_utils.py b/python/interpret-core/interpret/glassbox/_ebm/_utils.py index a6ebd46b7..232fa9fa2 100644 --- a/python/interpret-core/interpret/glassbox/_ebm/_utils.py +++ b/python/interpret-core/interpret/glassbox/_ebm/_utils.py @@ -90,8 +90,8 @@ def convert_categorical_to_continuous(categories): non_float_idxs = [idx for idx in non_float_idxs if idx not in clusters] non_float_idxs.append(max(categories.values()) + 1) - if len(clusters) <= 1: - return np.empty(0, np.float64) + if len(clusters) == 0: + return np.empty(0, np.float64), [[0], [], non_float_idxs], np.nan, np.nan cluster_bounds = sorted((min(cluster_list), max(cluster_list)) for cluster_list in clusters.values()) diff --git a/python/interpret-core/tests/glassbox/ebm/test_ebm_utils.py b/python/interpret-core/tests/glassbox/ebm/test_ebm_utils.py index c79a10742..0a79f0ec7 100644 --- a/python/interpret-core/tests/glassbox/ebm/test_ebm_utils.py +++ b/python/interpret-core/tests/glassbox/ebm/test_ebm_utils.py @@ -60,6 +60,27 @@ def test_make_bag_stratified(): bag = make_bag(y, test_size=0.25, rng=1, is_stratified=True) +def test_convert_categorical_to_continuous_none(): + cuts, mapping, old_min, old_max = convert_categorical_to_continuous( + {"ABCD": 1, "EFGH": 2, "IJKL": 1} + ) + assert len(cuts) == 0 + assert mapping == [[0], [], [1, 2, 3]] + assert np.isnan(old_min) + assert np.isnan(old_max) + + + +def test_convert_categorical_to_continuous_single(): + cuts, mapping, old_min, old_max = convert_categorical_to_continuous( + {"10": 1, "+10": 2, "30": 1} + ) + assert len(cuts) == 0 + assert mapping == [[0], [1, 2], [3]] + assert old_min == 10 + assert old_max == 30 + + def test_convert_categorical_to_continuous_easy(): cuts, mapping, old_min, old_max = convert_categorical_to_continuous( {"10": 1, "20": 2, "30": 3}