Skip to content

Commit

Permalink
fix predictions with boolean features (#12)
Browse files Browse the repository at this point in the history
fixes #11

ONNX converts boolean to strings as 0 and 1 while the EBM python
implementation expects False/True. To fix this, we replace the
bins keys with values returned by ONNX.

Signed-off-by: Romain Picard <[email protected]>
  • Loading branch information
MainRo authored Mar 4, 2024
1 parent 96d4345 commit 669c7ee
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 26 deletions.
8 changes: 8 additions & 0 deletions ebm2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def to_onnx(model, dtype, name="ebm",

feature_dtype = infer_features_dtype(dtype, feature_name)
part = graph.create_input(root, feature_name, feature_dtype, [None])
if feature_dtype == onnx.TensorProto.BOOL:
# ONNX converts booleans to strings 0/1, not False/True
col_mapping = {
'0': col_mapping['False'],
'1': col_mapping['True'],
}
# replace inplace to re-use it in interactions
model.bins_[feature_group[0]][0] = col_mapping
if feature_dtype != onnx.TensorProto.STRING:
part = ops.cast(onnx.TensorProto.STRING)(part)
part = ops.flatten()(part)
Expand Down
2 changes: 1 addition & 1 deletion ebm2onnx/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _argmax(g):


def cast(to):
def _cast(g):
def _cast(g):
cast_result_name = g.generate_name('cast_result')
nodes = [
onnx.helper.make_node("Cast", [g.transients[0].name], [cast_result_name], to=to),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def train_titanic_binary_classification(interactions=0, with_categorical=False):
feature_types=['continuous', 'continuous', 'continuous', 'continuous']
feature_columns = ['Age', 'Fare', 'Pclass', 'Old']
else:
feature_types=['continuous', 'continuous', 'nominal', 'continuous', 'nominal']
feature_types=['continuous', 'continuous', 'nominal', 'nominal', 'nominal']
feature_columns = ['Age', 'Fare', 'Pclass', 'Old', 'Embarked']
label_column = "Survived"

Expand Down
66 changes: 44 additions & 22 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import ebm2onnx.graph as graph
import ebm2onnx.operators as ops

Expand All @@ -24,29 +25,50 @@ def test_add():
)


def test_cast():
@pytest.mark.parametrize(
"from_type,to_type,input,output",
[
pytest.param(
onnx.TensorProto.INT64,
onnx.TensorProto.FLOAT,
{'i': [[1], [2], [11], [4]]},
[[[1.0], [2.0], [11.0], [4.0]]],
id='int64_to_float'
),
pytest.param(
onnx.TensorProto.INT64,
onnx.TensorProto.STRING,
{'i': [[1], [2], [11], [4]]},
[[["1"], ["2"], ["11"], ["4"]]],
id='int64_to_string'
),
pytest.param(
onnx.TensorProto.BOOL,
onnx.TensorProto.UINT8,
{'i': [[False], [True]]},
[[[0], [1]]],
id='bool_to_uint8'
),
pytest.param(
onnx.TensorProto.BOOL,
onnx.TensorProto.STRING,
{'i': [[False], [True]]},
[[["0"], ["1"]]],
id='bool_to_string'
),
]
)
def test_cast(from_type, to_type, input, output):
g = graph.create_graph()

i = graph.create_input(g, "i", onnx.TensorProto.INT64, [None, 1])

l = ops.cast(onnx.TensorProto.FLOAT)(i)
l = graph.add_output(l, l.transients[0].name, onnx.TensorProto.FLOAT, [None, 1])

assert_model_result(l,
input={
'i': [
[1],
[2],
[11],
[4],
]
},
expected_result=[[
[1.0],
[2.0],
[11.0],
[4.0]
]]
i = graph.create_input(g, "i", from_type, [None, 1])
l = ops.cast(to_type)(i)
l = graph.add_output(l, l.transients[0].name, to_type, [None, 1])

assert_model_result(
l,
input=input,
expected_result=output,
exact_match=to_type in [onnx.TensorProto.INT64, onnx.TensorProto.STRING]
)


Expand Down
14 changes: 12 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ def infer_model(model, input):
os.unlink(filename)


def assert_model_result(g, input, expected_result, atol=1e-08, save_path=None):
def assert_model_result(
g, input,
expected_result,
exact_match=False,
atol=1e-08,
save_path=None
):
model = graph.compile(g, target_opset=13)
_, filename = tempfile.mkstemp()
try:
Expand All @@ -45,8 +51,12 @@ def assert_model_result(g, input, expected_result, atol=1e-08, save_path=None):
pred = sess.run(None, input)

print(pred)
print(expected_result)
for i, p in enumerate(pred):
assert np.allclose(p, np.array(expected_result[i]))
if exact_match:
assert p.tolist() == expected_result[i]
else:
assert np.allclose(p, np.array(expected_result[i]))

finally:
os.unlink(filename)

0 comments on commit 669c7ee

Please sign in to comment.