Skip to content

Commit

Permalink
Consistent way of not instrumenting multiple times (#549)
Browse files Browse the repository at this point in the history
  • Loading branch information
lzchen authored Jul 9, 2021
1 parent bf97e17 commit 56da6d7
Show file tree
Hide file tree
Showing 17 changed files with 236 additions and 82 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#538](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/538))
- Changed the psycopg2-binary to psycopg2 as dependency in production
([#543](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/543))
- Implement consistent way of checking if instrumentation is already active
([#549](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/549))
- Require aiopg to be less than 1.3.0
([#560](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/560))
- `opentelemetry-instrumentation-django` Migrated Django middleware to new-style.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def instrumented_init(wrapped, instance, args, kwargs):
span_name=span_name,
tracer_provider=tracer_provider,
)
trace_config.opentelemetry_aiohttp_instrumented = True
trace_config._is_instrumented_by_opentelemetry = True
trace_configs.append(trace_config)

kwargs["trace_configs"] = trace_configs
Expand All @@ -282,7 +282,7 @@ def _uninstrument_session(client_session: aiohttp.ClientSession):
client_session._trace_configs = [
trace_config
for trace_config in trace_configs
if not hasattr(trace_config, "opentelemetry_aiohttp_instrumented")
if not hasattr(trace_config, "_is_instrumented_by_opentelemetry")
]


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import typing

import wrapt
from aiopg.utils import ( # pylint: disable=no-name-in-module
_ContextManager,
_PoolAcquireContextManager,
)
from aiopg.utils import _ContextManager, _PoolAcquireContextManager

from opentelemetry.instrumentation.dbapi import (
CursorTracer,
Expand Down Expand Up @@ -64,9 +61,7 @@ def __init__(self, connection, *args, **kwargs):

def cursor(self, *args, **kwargs):
coro = self._cursor(*args, **kwargs)
return _ContextManager( # pylint: disable=no-value-for-parameter
coro
)
return _ContextManager(coro)

async def _cursor(self, *args, **kwargs):
# pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from opentelemetry.instrumentation.aiopg.aiopg_integration import (
AiopgIntegration,
AsyncProxyObject,
get_traced_connection_proxy,
)
from opentelemetry.instrumentation.aiopg.version import __version__
Expand Down Expand Up @@ -153,6 +154,10 @@ def instrument_connection(
Returns:
An instrumented connection.
"""
if isinstance(connection, AsyncProxyObject):
logger.warning("Connection already instrumented")
return connection

db_integration = AiopgIntegration(
name,
database_system,
Expand All @@ -173,7 +178,7 @@ def uninstrument_connection(connection):
Returns:
An uninstrumented connection.
"""
if isinstance(connection, wrapt.ObjectProxy):
if isinstance(connection, AsyncProxyObject):
return connection.__wrapped__

logger.warning("Connection is not instrumented")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,23 @@ def test_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrument_connection_after_instrument(self):
cnx = async_call(aiopg.connect(database="test"))
query = "SELECT * FROM test"
cursor = async_call(cnx.cursor())
async_call(cursor.execute(query))

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 0)

AiopgInstrumentor().instrument()
cnx = AiopgInstrumentor().instrument_connection(cnx)
cursor = async_call(cnx.cursor())
async_call(cursor.execute(query))

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_custom_tracer_provider_instrument_connection(self):
resource = resources.Resource.create(
{"service.name": "db-test-service"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def instrument_connection(
Returns:
An instrumented connection.
"""
if isinstance(connection, wrapt.ObjectProxy):
logger.warning("Connection already instrumented")
return connection

db_integration = DatabaseApiIntegration(
name,
database_system,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Collection

import fastapi
from starlette import middleware
from starlette.routing import Match

from opentelemetry.instrumentation.asgi import OpenTelemetryMiddleware
Expand All @@ -24,6 +26,7 @@
from opentelemetry.util.http import get_excluded_urls, parse_excluded_urls

_excluded_urls_from_env = get_excluded_urls("FASTAPI")
_logger = logging.getLogger(__name__)


class FastAPIInstrumentor(BaseInstrumentor):
Expand All @@ -39,7 +42,10 @@ def instrument_app(
app: fastapi.FastAPI, tracer_provider=None, excluded_urls=None,
):
"""Instrument an uninstrumented FastAPI application."""
if not getattr(app, "is_instrumented_by_opentelemetry", False):
if not hasattr(app, "_is_instrumented_by_opentelemetry"):
app._is_instrumented_by_opentelemetry = False

if not getattr(app, "_is_instrumented_by_opentelemetry", False):
if excluded_urls is None:
excluded_urls = _excluded_urls_from_env
else:
Expand All @@ -51,7 +57,21 @@ def instrument_app(
span_details_callback=_get_route_details,
tracer_provider=tracer_provider,
)
app.is_instrumented_by_opentelemetry = True
app._is_instrumented_by_opentelemetry = True
else:
_logger.warning(
"Attempting to instrument FastAPI app while already instrumented"
)

@staticmethod
def uninstrument_app(app: fastapi.FastAPI):
app.user_middleware = [
x
for x in app.user_middleware
if x.cls is not OpenTelemetryMiddleware
]
app.middleware_stack = app.build_middleware_stack()
app._is_instrumented_by_opentelemetry = False

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from fastapi.testclient import TestClient

import opentelemetry.instrumentation.fastapi as otel_fastapi
from opentelemetry.instrumentation.asgi import OpenTelemetryMiddleware
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
Expand Down Expand Up @@ -57,6 +58,47 @@ def tearDown(self):
super().tearDown()
self.env_patch.stop()
self.exclude_patch.stop()
with self.disable_logging():
self._instrumentor.uninstrument()
self._instrumentor.uninstrument_app(self._app)

def test_instrument_app_with_instrument(self):
if not isinstance(self, TestAutoInstrumentation):
self._instrumentor.instrument()
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
for span in spans:
self.assertIn("/foobar", span.name)

def test_uninstrument_app(self):
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
# pylint: disable=import-outside-toplevel
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

self._app.add_middleware(HTTPSRedirectMiddleware)
self._instrumentor.uninstrument_app(self._app)
print(self._app.user_middleware[0].cls)
self.assertFalse(
isinstance(
self._app.user_middleware[0].cls, OpenTelemetryMiddleware
)
)
self._client = TestClient(self._app)
resp = self._client.get("/foobar")
self.assertEqual(200, resp.status_code)
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 3)

def test_uninstrument_app_after_instrument(self):
if not isinstance(self, TestAutoInstrumentation):
self._instrumentor.instrument()
self._instrumentor.uninstrument_app(self._app)
self._client.get("/foobar")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)

def test_basic_fastapi_call(self):
self._client.get("/foobar")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ class _InstrumentedFlask(flask.Flask):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._original_wsgi_ = self.wsgi_app
self._original_wsgi_app = self.wsgi_app
self._is_instrumented_by_opentelemetry = True

self.wsgi_app = _rewrapped_app(
self.wsgi_app, _InstrumentedFlask._response_hook
Expand Down Expand Up @@ -229,18 +230,21 @@ def _instrument(self, **kwargs):
_InstrumentedFlask._request_hook = request_hook
if callable(response_hook):
_InstrumentedFlask._response_hook = response_hook
flask.Flask = _InstrumentedFlask
tracer_provider = kwargs.get("tracer_provider")
_InstrumentedFlask._tracer_provider = tracer_provider
flask.Flask = _InstrumentedFlask

def _uninstrument(self, **kwargs):
flask.Flask = self._original_flask

@staticmethod
def instrument_app(
self, app, request_hook=None, response_hook=None, tracer_provider=None
): # pylint: disable=no-self-use
if not hasattr(app, "_is_instrumented"):
app._is_instrumented = False
app, request_hook=None, response_hook=None, tracer_provider=None
):
if not hasattr(app, "_is_instrumented_by_opentelemetry"):
app._is_instrumented_by_opentelemetry = False

if not app._is_instrumented:
if not app._is_instrumented_by_opentelemetry:
app._original_wsgi_app = app.wsgi_app
app.wsgi_app = _rewrapped_app(app.wsgi_app, response_hook)

Expand All @@ -250,28 +254,22 @@ def instrument_app(
app._before_request = _before_request
app.before_request(_before_request)
app.teardown_request(_teardown_request)
app._is_instrumented = True
app._is_instrumented_by_opentelemetry = True
else:
_logger.warning(
"Attempting to instrument Flask app while already instrumented"
)

def _uninstrument(self, **kwargs):
flask.Flask = self._original_flask

def uninstrument_app(self, app): # pylint: disable=no-self-use
if not hasattr(app, "_is_instrumented"):
app._is_instrumented = False

if app._is_instrumented:
@staticmethod
def uninstrument_app(app):
if hasattr(app, "_original_wsgi_app"):
app.wsgi_app = app._original_wsgi_app

# FIXME add support for other Flask blueprints that are not None
app.before_request_funcs[None].remove(app._before_request)
app.teardown_request_funcs[None].remove(_teardown_request)
del app._original_wsgi_app

app._is_instrumented = False
app._is_instrumented_by_opentelemetry = False
else:
_logger.warning(
"Attempting to uninstrument Flask "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,16 @@ def tearDown(self):
with self.disable_logging():
FlaskInstrumentor().uninstrument_app(self.app)

def test_uninstrument(self):
def test_instrument_app_and_instrument(self):
FlaskInstrumentor().instrument()
resp = self.client.get("/hello/123")
self.assertEqual(200, resp.status_code)
self.assertEqual([b"Hello: 123"], list(resp.response))
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
FlaskInstrumentor().uninstrument()

def test_uninstrument_app(self):
resp = self.client.get("/hello/123")
self.assertEqual(200, resp.status_code)
self.assertEqual([b"Hello: 123"], list(resp.response))
Expand All @@ -94,6 +103,16 @@ def test_uninstrument(self):
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)

def test_uninstrument_app_after_instrument(self):
FlaskInstrumentor().instrument()
FlaskInstrumentor().uninstrument_app(self.app)
resp = self.client.get("/hello/123")
self.assertEqual(200, resp.status_code)
self.assertEqual([b"Hello: 123"], list(resp.response))
span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 0)
FlaskInstrumentor().uninstrument()

# pylint: disable=no-member
def test_only_strings_in_environ(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
---
"""

import logging
import typing
from typing import Collection

Expand All @@ -53,6 +54,7 @@
from opentelemetry.instrumentation.psycopg2.package import _instruments
from opentelemetry.instrumentation.psycopg2.version import __version__

_logger = logging.getLogger(__name__)
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"


Expand Down Expand Up @@ -91,24 +93,32 @@ def _uninstrument(self, **kwargs):
dbapi.unwrap_connect(psycopg2, "connect")

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
def instrument_connection(
self, connection, tracer_provider=None
): # pylint: disable=no-self-use
setattr(
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
)
connection.cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
@staticmethod
def instrument_connection(connection, tracer_provider=None):
if not hasattr(connection, "_is_instrumented_by_opentelemetry"):
connection._is_instrumented_by_opentelemetry = False

if not connection._is_instrumented_by_opentelemetry:
setattr(
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
)
connection.cursor_factory = _new_cursor_factory(
tracer_provider=tracer_provider
)
connection._is_instrumented_by_opentelemetry = True
else:
_logger.warning(
"Attempting to instrument Psycopg connection while already instrumented"
)
return connection

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
def uninstrument_connection(
self, connection
): # pylint: disable=no-self-use
@staticmethod
def uninstrument_connection(connection):
connection.cursor_factory = getattr(
connection, _OTEL_CURSOR_FACTORY_KEY, None
)

return connection


Expand Down
Loading

0 comments on commit 56da6d7

Please sign in to comment.