Skip to content

Commit

Permalink
fix conversion error when a boolean feature has only one value
Browse files Browse the repository at this point in the history
  • Loading branch information
MainRo committed Mar 5, 2024
1 parent 958c102 commit 92eebb0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
8 changes: 6 additions & 2 deletions ebm2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
'str': onnx.TensorProto.STRING,
}

bool_remap = {
'False': '0',
'True': '1',
}

def infer_features_dtype(dtype, feature_name):
feature_dtype = onnx.TensorProto.DOUBLE
Expand Down Expand Up @@ -126,8 +130,8 @@ def to_onnx(model, dtype, name="ebm",
if feature_dtype == onnx.TensorProto.BOOL:
# ONNX converts booleans to strings 0/1, not False/True
col_mapping = {
'0': col_mapping['False'],
'1': col_mapping['True'],
bool_remap[k]: v
for k, v in col_mapping.items()
}
# replace inplace to re-use it in interactions
model.bins_[feature_group[0]][0] = col_mapping
Expand Down
8 changes: 5 additions & 3 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .utils import infer_model, create_session


def train_titanic_binary_classification(interactions=0, with_categorical=False):
def train_titanic_binary_classification(interactions=0, with_categorical=False, old_th=65):
df = pd.read_csv(
os.path.join('examples','titanic_train.csv'),
#dtype= {
Expand All @@ -22,7 +22,7 @@ def train_titanic_binary_classification(interactions=0, with_categorical=False):
#}
)
df = df.dropna()
df['Old'] = df['Age'] > 65
df['Old'] = df['Age'] > old_th
if with_categorical is False:
feature_types=['continuous', 'continuous', 'continuous', 'continuous']
feature_columns = ['Age', 'Fare', 'Pclass', 'Old']
Expand Down Expand Up @@ -168,10 +168,12 @@ def test_predict_regression_without_interactions(interactions, explain):

@pytest.mark.parametrize("explain", [False, True])
@pytest.mark.parametrize("interactions", [0, 2, [(0, 1, 2)], [(0, 1, 2, 3)]])
def test_predict_binary_classification_with_categorical(interactions, explain):
@pytest.mark.parametrize("old_th", [65, 0])
def test_predict_binary_classification_with_categorical(interactions, explain, old_th):
model_ebm, x_test, y_test = train_titanic_binary_classification(
interactions=interactions,
with_categorical=True,
old_th=old_th,
)
pred_ebm = model_ebm.predict(x_test)

Expand Down

0 comments on commit 92eebb0

Please sign in to comment.