Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix async redis clients tracing #1830

Merged
merged 9 commits into from
Jun 25, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed
- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830))

## Version 1.18.0/0.39b0 (2023-05-10)

- `opentelemetry-instrumentation-system-metrics` Add `process.` prefix to `runtime.memory`, `runtime.cpu.time`, and `runtime.gc_count`. Change `runtime.memory` from count to UpDownCounter. ([#1735](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1735))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,44 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _build_span_name(instance, cmd_args):
if len(cmd_args) > 0 and cmd_args[0]:
name = cmd_args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
return name


def _build_span_meta_data_for_pipeline(instance, sanitize_query):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
)
for c in command_stack
]
resource = "\n".join(cmds)

span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""

return command_stack, resource, span_name


def _instrument(
tracer,
request_hook: _RequestHookT = None,
Expand All @@ -165,11 +203,8 @@ def _instrument(
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args, sanitize_query)
name = _build_span_name(instance, args)

if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
Expand All @@ -185,31 +220,11 @@ def _traced_execute_command(func, instance, args, kwargs):
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
)
for c in command_stack
]
resource = "\n".join(cmds)

span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance, sanitize_query)

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
Expand Down Expand Up @@ -254,32 +269,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
"ClusterPipeline.execute",
_traced_execute_pipeline,
)

async def _async_traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args, sanitize_query)
name = _build_span_name(instance, args)

with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
if callable(request_hook):
request_hook(span, instance, args, kwargs)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance, sanitize_query)

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(command_stack)
)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
wrap_function_wrapper(
"redis.asyncio",
f"{redis_class}.execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
wrap_function_wrapper(
"redis.asyncio.client",
f"{pipeline_class}.execute",
_traced_execute_pipeline,
_async_traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.asyncio.client",
f"{pipeline_class}.immediate_execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
wrap_function_wrapper(
"redis.asyncio.cluster",
"RedisCluster.execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
wrap_function_wrapper(
"redis.asyncio.cluster",
"ClusterPipeline.execute",
_traced_execute_pipeline,
_async_traced_execute_pipeline,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from unittest import mock

import redis
import redis.asyncio

from opentelemetry import trace
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind


class AsyncMock:
"""A sufficient async mock implementation.

Python 3.7 doesn't have an inbuilt async mock class, so this is used.
"""

def __init__(self):
self.mock = mock.Mock()

async def __call__(self, *args, **kwargs):
f = asyncio.Future()
f.set_result("random")
return f

def __getattr__(self, item):
return AsyncMock()


class TestRedis(TestBase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -87,6 +107,35 @@ def test_instrument_uninstrument(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_instrument_uninstrument_async_client_command(self):
redis_client = redis.asyncio.Redis()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

# Test uninstrument
RedisInstrumentor().uninstrument()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
self.memory_exporter.clear()

# Test instrument again
RedisInstrumentor().instrument()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_response_hook(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
Expand Down