Skip to content

Commit

Permalink
Include training step in metric scorer name (#712)
Browse files Browse the repository at this point in the history
* Include training step in scorer name

* Add fit_predict data proxying

* Remove name comments

* Fix predict being called before fit

* Re-use existing fixture
  • Loading branch information
hmstepanek authored Dec 16, 2022
1 parent c6a9d4c commit d9d5636
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 15 deletions.
22 changes: 17 additions & 5 deletions newrelic/hooks/mlmodel_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@


class PredictReturnTypeProxy(ObjectProxy):
    """Transparent proxy around a model's prediction result.

    Wraps the value returned by ``predict``/``fit_predict`` so that the
    originating model's class name and the training step at which the
    prediction was made travel with the data. Downstream metric scorers
    read these ``_nr_*`` attributes to build their attribute names.
    """

    def __init__(self, wrapped, model_name, training_step):
        # Initialize the wrapt ObjectProxy machinery around the raw value.
        super(ObjectProxy, self).__init__(wrapped)
        # Class name of the model that produced this prediction.
        self._nr_model_name = model_name
        # 0-based count of fit calls made on the model before this prediction.
        self._nr_training_step = training_step


def _wrap_method_trace(module, _class, method, name=None, group=None):
Expand Down Expand Up @@ -65,10 +66,16 @@ def _nr_wrapper_method(wrapped, instance, args, kwargs):
# Set the _nr_wrapped attribute to denote that this method is no longer wrapped.
setattr(trace, wrapped_attr_name, False)

# If this is the fit method, increment the training_step counter.
if method in ("fit", "fit_predict"):
training_step = getattr(instance, "_nr_wrapped_training_step", -1)
setattr(instance, "_nr_wrapped_training_step", training_step + 1)

# If this is the predict method, wrap the return type in an nr type with
# _nr_wrapped attrs that will attach model info to the data.
if method == "predict":
return PredictReturnTypeProxy(return_val, model_name=_class)
if method in ("predict", "fit_predict"):
training_step = getattr(instance, "_nr_wrapped_training_step", "Unknown")
return PredictReturnTypeProxy(return_val, model_name=_class, training_step=training_step)
return return_val

wrap_function_wrapper(module, "%s.%s" % (_class, method), _nr_wrapper_method)
Expand Down Expand Up @@ -102,16 +109,21 @@ def wrap_metric_scorer(wrapped, instance, args, kwargs):

y_true, y_pred, args, kwargs = _bind_scorer(*args, **kwargs)
model_name = "Unknown"
training_step = "Unknown"
if hasattr(y_pred, "_nr_model_name"):
model_name = y_pred._nr_model_name
if hasattr(y_pred, "_nr_training_step"):
training_step = y_pred._nr_training_step
# Attribute values must be int, float, str, or boolean. If it's not one of these
# types and an iterable add the values as separate attributes.
if not isinstance(score, (str, int, float, bool)):
if hasattr(score, "__iter__"):
for i, s in enumerate(score):
transaction._add_agent_attribute("%s.%s[%s]" % (model_name, wrapped.__name__, i), s)
transaction._add_agent_attribute(
"%s/TrainingStep/%s/%s[%s]" % (model_name, training_step, wrapped.__name__, i), s
)
else:
transaction._add_agent_attribute("%s.%s" % (model_name, wrapped.__name__), score)
transaction._add_agent_attribute("%s/TrainingStep/%s/%s" % (model_name, training_step, wrapped.__name__), score)
return score


Expand Down
52 changes: 42 additions & 10 deletions tests/mlmodel_sklearn/test_metric_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,41 @@
),
)
def test_metric_scorer_attributes(metric_scorer_name, run_metric_scorer):
    # A single fit runs at training step 0, so the recorded agent attribute
    # name is Model/TrainingStep/<step>/<scorer>.
    @validate_attributes("agent", ["DecisionTreeClassifier/TrainingStep/0/%s" % metric_scorer_name])
    @background_task()
    def _test():
        run_metric_scorer(metric_scorer_name)

    _test()


@pytest.mark.parametrize(
    "metric_scorer_name",
    (
        "accuracy_score",
        "balanced_accuracy_score",
        "f1_score",
        "precision_score",
        "recall_score",
        "roc_auc_score",
        "r2_score",
    ),
)
def test_metric_scorer_training_steps_attributes(metric_scorer_name, run_metric_scorer):
    # Fitting the model twice (training_steps=[0, 1]) must produce one agent
    # attribute per training step, each tagged with its own step number.
    @validate_attributes(
        "agent",
        [
            "DecisionTreeClassifier/TrainingStep/0/%s" % metric_scorer_name,
            "DecisionTreeClassifier/TrainingStep/1/%s" % metric_scorer_name,
        ],
    )
    @background_task()
    def _test():
        run_metric_scorer(metric_scorer_name, training_steps=[0, 1])

    _test()


@pytest.mark.parametrize(
"metric_scorer_name,kwargs",
[
Expand All @@ -53,8 +80,8 @@ def test_metric_scorer_iterable_score_attributes(metric_scorer_name, kwargs, run
@validate_attributes(
"agent",
[
"DecisionTreeClassifier.%s[0]" % metric_scorer_name,
"DecisionTreeClassifier.%s[1]" % metric_scorer_name,
"DecisionTreeClassifier/TrainingStep/0/%s[0]" % metric_scorer_name,
"DecisionTreeClassifier/TrainingStep/0/%s[1]" % metric_scorer_name,
],
)
@background_task()
Expand All @@ -77,7 +104,7 @@ def _test():
],
)
def test_metric_scorer_attributes_unknown_model(metric_scorer_name):
@validate_attributes("agent", ["Unknown.%s" % metric_scorer_name])
@validate_attributes("agent", ["Unknown/TrainingStep/Unknown/%s" % metric_scorer_name])
@background_task()
def _test():
from sklearn import metrics
Expand All @@ -92,27 +119,32 @@ def _test():

@pytest.mark.parametrize("data", (np.array([0, 1]), "foo", 1, 1.0, True, [0, 1], {"foo": "bar"}, (0, 1), np.str_("F")))
def test_PredictReturnTypeProxy(data):
    # The proxy must expose both the model name and the training step it was
    # constructed with, regardless of the wrapped value's type.
    wrapped_data = PredictReturnTypeProxy(data, "ModelName", 0)

    assert wrapped_data._nr_model_name == "ModelName"
    assert wrapped_data._nr_training_step == 0


@pytest.fixture
def run_metric_scorer():
    """Return a helper that fits a classifier and applies a named metric scorer.

    The returned ``_run`` callable fits a ``DecisionTreeClassifier`` once per
    entry in ``training_steps`` (default: a single step), predicts on a small
    fixed test set, and invokes ``sklearn.metrics.<metric_scorer_name>`` on the
    results with any extra keyword arguments supplied.
    """

    def _run(metric_scorer_name, metric_scorer_kwargs=None, training_steps=None):
        from sklearn import metrics, tree

        x_train = [[0, 0], [1, 1]]
        y_train = [0, 1]
        x_test = [[2.0, 2.0], [0, 0.5]]
        y_test = [1, 0]

        # Default to a single training step when none (or an empty list) is given.
        if not training_steps:
            training_steps = [0]

        clf = tree.DecisionTreeClassifier(random_state=0)
        # Each fit call advances the instrumentation's training-step counter;
        # the step values themselves are not used here, only the call count.
        for _ in training_steps:
            model = clf.fit(x_train, y_train)

        labels = model.predict(x_test)

        metric_scorer_kwargs = metric_scorer_kwargs or {}
        getattr(metrics, metric_scorer_name)(y_test, labels, **metric_scorer_kwargs)

    return _run

0 comments on commit d9d5636

Please sign in to comment.