Skip to content

Commit

Permalink
handle instrumentation of delegators
Browse files Browse the repository at this point in the history
  • Loading branch information
crflynn committed Oct 29, 2020
1 parent 9c55e99 commit d90763f
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from sklearn.base import BaseEstimator
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.tree import BaseDecisionTree
from sklearn.utils.metaestimators import _IffHasAttrDescriptor

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.sklearn.version import __version__
Expand Down Expand Up @@ -104,6 +105,55 @@ def wrapper(*args, **kwargs):
return wrapper


def implement_spans_delegator(obj: _IffHasAttrDescriptor):
"""Wrap the descriptor's fn with a span.
Args:
obj: An instance of _IffHasAttrDescriptor
"""
# Don't instrument inherited delegators
if hasattr(obj, "_otel_original_fn"):
logger.debug("Already instrumented: %s", obj.fn.__qualname__)
return

def implement_spans_get(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
with get_tracer(__name__, __version__).start_as_current_span(
name=func.__qualname__
):
return func(*args, **kwargs)

return wrapper

logger.debug("Instrumenting: %s", obj.fn.__qualname__)

setattr(obj, "_otel_original_fn", getattr(obj, "fn"))
setattr(obj, "fn", implement_spans_get(obj.fn))


def get_delegator(
estimator: Type[BaseEstimator], method_name: str
) -> Union[_IffHasAttrDescriptor, None]:
"""Get the delegator from a class method or None.
Args:
estimator (BaseEstimator): A class derived from ``sklearn``'s
``BaseEstimator``.
method_name (str): The method name of the estimator on which to
check for delegation.
Returns:
The delegator, if one exists, otherwise None.
"""
class_attr = getattr(estimator, method_name)
if getattr(class_attr, "__closure__", None) is not None:
for cell in class_attr.__closure__:
if isinstance(cell.cell_contents, _IffHasAttrDescriptor):
return cell.cell_contents
return None


def get_base_estimators(packages: List[str]) -> Dict[str, Type[BaseEstimator]]:
"""Walk package hierarchies to get BaseEstimator-derived classes.
Expand Down Expand Up @@ -389,7 +439,7 @@ def _check_instrumented(
method_name (str): The method name of the estimator on which to
check for instrumentation.
"""
orig_method_name = "_original_" + method_name
orig_method_name = "_otel_original_" + method_name
has_original = hasattr(estimator, orig_method_name)
orig_class, orig_method = getattr(
estimator, orig_method_name, (None, None)
Expand Down Expand Up @@ -419,11 +469,12 @@ def _uninstrument_class_method(
method_name (str): The method name of the estimator on which to
apply a span.
"""
orig_method_name = "_original_" + method_name
orig_method_name = "_otel_original_" + method_name
if isclass(estimator):
qualname = estimator.__qualname__
else:
qualname = estimator.__class__.__qualname__
delegator = get_delegator(estimator, method_name)
if self._check_instrumented(estimator, method_name):
logger.debug(
"Uninstrumenting: %s.%s", qualname, method_name,
Expand All @@ -433,6 +484,16 @@ def _uninstrument_class_method(
estimator, method_name, orig_method,
)
delattr(estimator, orig_method_name)
elif delegator is not None:
if not hasattr(delegator, "_otel_original_fn"):
logger.debug(
"Already uninstrumented: %s.%s", qualname, method_name,
)
return
setattr(
delegator, "fn", getattr(delegator, "_otel_original_fn"),
)
delattr(delegator, "_otel_original_fn")
else:
logger.debug(
"Already uninstrumented: %s.%s", qualname, method_name,
Expand All @@ -452,7 +513,7 @@ def _uninstrument_instance_method(
method_name (str): The method name of the estimator on which to
apply a span.
"""
orig_method_name = "_original_" + method_name
orig_method_name = "_otel_original_" + method_name
if isclass(estimator):
qualname = estimator.__qualname__
else:
Expand Down Expand Up @@ -496,37 +557,25 @@ def _instrument_class_method(
)
return
class_attr = getattr(estimator, method_name)
delegator = get_delegator(estimator, method_name)
if isinstance(class_attr, property):
logger.debug(
"Not instrumenting found property: %s.%s",
estimator.__qualname__,
method_name,
)
elif delegator is not None:
implement_spans_delegator(delegator)
else:
setattr(
estimator, "_original_" + method_name, (estimator, class_attr),
estimator,
"_otel_original_" + method_name,
(estimator, class_attr),
)
setattr(
estimator, method_name, self.spanner(class_attr, estimator),
)

def _function_wrapper(self, function):
"""Get the inner-most decorator of a function."""
if hasattr(function, "__wrapped__"):
if hasattr(function.__wrapped__, "__wrapped__"):
return self._function_wrapper(function.__wrapped__)
return function
return None

def _function_wrapper_wrapper(self, function):
"""Get the second inner-most decorator of a function"""
if hasattr(function, "__wrapped__"):
if hasattr(function.__wrapped__, "__wrapped__"):
if hasattr(function.__wrapped__.__wrapped__, "__wrapped__"):
return self._function_wrapper_wrapper(function.__wrapped__)
return function
return None

def _unwrap_function(self, function):
"""Fetch the function underlying any decorators"""
if hasattr(function, "__wrapped__"):
Expand Down Expand Up @@ -564,7 +613,9 @@ def _instrument_instance_method(
)
else:
method = getattr(estimator, method_name)
setattr(estimator, "_original_" + method_name, (estimator, method))
setattr(
estimator, "_otel_original_" + method_name, (estimator, method)
)
setattr(
estimator, method_name, self.spanner(method, estimator),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,54 @@
DEFAULT_METHODS,
SklearnInstrumentor,
get_base_estimators,
get_delegator,
)
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind

from .fixtures import pipeline, random_input


def assert_instrumented(base_estimators):
for _, estimator in base_estimators.items():
for method_name in DEFAULT_METHODS:
original_method_name = "_otel_original_" + method_name
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
assert not hasattr(estimator, original_method_name)
continue
class_attr = getattr(estimator, method_name, None)
if isinstance(class_attr, property):
assert not hasattr(estimator, original_method_name)
continue
delegator = None
if hasattr(estimator, method_name):
delegator = get_delegator(estimator, method_name)
if delegator is not None:
assert hasattr(delegator, "_otel_original_fn")
elif hasattr(estimator, method_name):
assert hasattr(estimator, original_method_name)


def assert_uninstrumented(base_estimators):
for _, estimator in base_estimators.items():
for method_name in DEFAULT_METHODS:
original_method_name = "_otel_original_" + method_name
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
assert not hasattr(estimator, original_method_name)
continue
class_attr = getattr(estimator, method_name, None)
if isinstance(class_attr, property):
assert not hasattr(estimator, original_method_name)
continue
delegator = None
if hasattr(estimator, method_name):
delegator = get_delegator(estimator, method_name)
if delegator is not None:
assert not hasattr(delegator, "_otel_original_fn")
elif hasattr(estimator, method_name):
assert not hasattr(estimator, original_method_name)


class TestSklearn(TestBase):
def test_package_instrumentation(self):
ski = SklearnInstrumentor()
Expand All @@ -21,42 +62,18 @@ def test_package_instrumentation(self):
model = pipeline()

ski.instrument()
# assert instrumented
for _, estimator in base_estimators.items():
for method_name in DEFAULT_METHODS:
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
assert not hasattr(estimator, "_original_" + method_name)
continue
class_attr = getattr(estimator, method_name, None)
if isinstance(class_attr, property):
assert not hasattr(estimator, "_original_" + method_name)
continue
if hasattr(estimator, method_name):
assert hasattr(estimator, "_original_" + method_name)
assert_instrumented(base_estimators)

x_test = random_input()

model.predict(x_test)

spans = self.memory_exporter.get_finished_spans()
for span in spans:
print(span)
self.assertEqual(len(spans), 8)
self.memory_exporter.clear()

ski.uninstrument()
# assert uninstrumented
for _, estimator in base_estimators.items():
for method_name in DEFAULT_METHODS:
if issubclass(estimator, tuple(DEFAULT_EXCLUDE_CLASSES)):
assert not hasattr(estimator, "_original_" + method_name)
continue
class_attr = getattr(estimator, method_name, None)
if isinstance(class_attr, property):
assert not hasattr(estimator, "_original_" + method_name)
continue
if hasattr(estimator, method_name):
assert not hasattr(estimator, "_original_" + method_name)
assert_uninstrumented(base_estimators)

model = pipeline()
x_test = random_input()
Expand Down

0 comments on commit d90763f

Please sign in to comment.