Skip to content

Commit

Permalink
fix output names for scores and predict_proba
Browse files Browse the repository at this point in the history
Fixes #3
  • Loading branch information
MainRo committed Aug 2, 2022
1 parent e484814 commit d876d2d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
3 changes: 2 additions & 1 deletion ebm2onnx/ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _compute_class_score(g):
)

g = ops.concat(axis=1)(g)
g = ops.identity("scores")(g)
scores_output_name = g.transients[0].name
g = ops.reduce_sum(keepdims=0)(graph.merge(g, init_sum_axis))
g = ops.add()(graph.merge(g, init_intercept))
Expand Down Expand Up @@ -152,7 +153,7 @@ def _predict_proba(g):

g = ops.mul()(graph.merge(g, init_zeros))
g = ops.softmax(axis=1)(g)
g = ops.identity("predict")(g)
g = ops.identity("predict_proba")(g)
return g

return _predict_proba
Expand Down
53 changes: 52 additions & 1 deletion tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import onnx
import ebm2onnx
from .utils import infer_model
from .utils import infer_model, create_session


def train_titanic_binary_classification(interactions, with_categorical=False):
Expand Down Expand Up @@ -330,3 +330,54 @@ def test_predict_proba_multiclass_classification():

assert len(pred_onnx) == 2
assert np.allclose(pred_ebm, pred_onnx[0])


def test_predict_w_scores_outputs_def():
model_ebm, _, _ = train_titanic_binary_classification(interactions=0)

model_onnx = ebm2onnx.to_onnx(
model_ebm,
explain=True,
dtype={
'Age': 'double',
'Fare': 'double',
'Pclass': 'int',
'Old': 'bool',
}
)
session = create_session(model_onnx)

outputs = session.get_outputs()
assert len(outputs) == 2
assert outputs[0].name == "predict_0"
assert outputs[0].shape == [None]
assert outputs[0].type == 'tensor(int64)'
assert outputs[1].name == "scores_0"
assert outputs[1].shape == [None, 4, 1]
assert outputs[1].type == 'tensor(float)'


def test_predict_proba_w_scores_outputs_def():
model_ebm, _, _ = train_titanic_binary_classification(interactions=0)

model_onnx = ebm2onnx.to_onnx(
model_ebm,
predict_proba=True,
explain=True,
dtype={
'Age': 'double',
'Fare': 'double',
'Pclass': 'int',
'Old': 'bool',
}
)
session = create_session(model_onnx)

outputs = session.get_outputs()
assert len(outputs) == 2
assert outputs[0].name == "predict_proba_0"
assert outputs[0].shape == [None, 2]
assert outputs[0].type == 'tensor(float)'
assert outputs[1].name == "scores_0"
assert outputs[1].shape == [None, 4, 1]
assert outputs[1].type == 'tensor(float)'
10 changes: 10 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
import ebm2onnx.graph as graph


def create_session(model):
_, filename = tempfile.mkstemp()
try:
onnx.save_model(model, filename)
sess = rt.InferenceSession(filename)
return sess
finally:
os.unlink(filename)


def infer_model(model, input):
_, filename = tempfile.mkstemp()
try:
Expand Down

0 comments on commit d876d2d

Please sign in to comment.