Skip to content

Commit

Permalink
Sanitize redis db_statement by default (#1776)
Browse files Browse the repository at this point in the history
Co-authored-by: Srikanth Chekuri <[email protected]>
Co-authored-by: Shalev Roda <[email protected]>
  • Loading branch information
3 people authored Jun 13, 2023
1 parent 818ef43 commit 37d85f0
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 134 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix redis db.statement attribute to be sanitized by default
([#1776](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1776))
- Fix elasticsearch db.statement attribute to be sanitized by default
([#1758](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1758))
- Fix `AttributeError` when AWS Lambda handler receives a list event
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ async def redis_get():
response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request
this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None
sanitize_query (Boolean) - default False, enable the Redis query sanitization
for example:
.. code: python
Expand All @@ -88,37 +86,18 @@ def response_hook(span, instance, response):
client = redis.StrictRedis(host="localhost", port=6379)
client.get("my-key")
Configuration
-------------
Query sanitization
******************
To enable query sanitization with an environment variable, set
``OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS`` to "true".
For example,
::
export OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS="true"
will result in traced queries like "SET ? ?".
API
---
"""
import typing
from os import environ
from typing import Any, Collection

import redis
from wrapt import wrap_function_wrapper

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.redis.environment_variables import (
OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS,
)
from opentelemetry.instrumentation.redis.package import _instruments
from opentelemetry.instrumentation.redis.util import (
_extract_conn_attributes,
Expand Down Expand Up @@ -161,10 +140,9 @@ def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
sanitize_query: bool = False,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args, sanitize_query)
query = _format_command_args(args)

if len(args) > 0 and args[0]:
name = args[0]
Expand Down Expand Up @@ -194,7 +172,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs):

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
c.args if hasattr(c, "args") else c[0],
)
for c in command_stack
]
Expand Down Expand Up @@ -307,15 +285,6 @@ def _instrument(self, **kwargs):
tracer,
request_hook=kwargs.get("request_hook"),
response_hook=kwargs.get("response_hook"),
sanitize_query=kwargs.get(
"sanitize_query",
environ.get(
OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS, "false"
)
.lower()
.strip()
== "true",
),
)

def _uninstrument(self, **kwargs):
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -48,41 +48,23 @@ def _extract_conn_attributes(conn_kwargs):
return attributes


def _format_command_args(args, sanitize_query):
def _format_command_args(args):
"""Format and sanitize command arguments, and trim them as needed"""
cmd_max_len = 1000
value_too_long_mark = "..."
if sanitize_query:
# Sanitized query format: "COMMAND ? ?"
args_length = len(args)
if args_length > 0:
out = [str(args[0])] + ["?"] * (args_length - 1)
out_str = " ".join(out)

if len(out_str) > cmd_max_len:
out_str = (
out_str[: cmd_max_len - len(value_too_long_mark)]
+ value_too_long_mark
)
else:
out_str = ""
return out_str
# Sanitized query format: "COMMAND ? ?"
args_length = len(args)
if args_length > 0:
out = [str(args[0])] + ["?"] * (args_length - 1)
out_str = " ".join(out)

value_max_len = 100
length = 0
out = []
for arg in args:
cmd = str(arg)
if len(out_str) > cmd_max_len:
out_str = (
out_str[: cmd_max_len - len(value_too_long_mark)]
+ value_too_long_mark
)
else:
out_str = ""

if len(cmd) > value_max_len:
cmd = cmd[:value_max_len] + value_too_long_mark

if length + len(cmd) > cmd_max_len:
prefix = cmd[: cmd_max_len - length]
out.append(f"{prefix}{value_too_long_mark}")
break

out.append(cmd)
length += len(cmd)

return " ".join(out)
return out_str
Original file line number Diff line number Diff line change
Expand Up @@ -168,22 +168,11 @@ def test_query_sanitizer_enabled(self):
span = spans[0]
self.assertEqual(span.attributes.get("db.statement"), "SET ? ?")

def test_query_sanitizer_enabled_env(self):
def test_query_sanitizer(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
redis_client.connection = connection

RedisInstrumentor().uninstrument()

env_patch = mock.patch.dict(
"os.environ",
{"OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS": "true"},
)
env_patch.start()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider,
)

with mock.patch.object(redis_client, "connection"):
redis_client.set("key", "value")

Expand All @@ -192,21 +181,6 @@ def test_query_sanitizer_enabled_env(self):

span = spans[0]
self.assertEqual(span.attributes.get("db.statement"), "SET ? ?")
env_patch.stop()

def test_query_sanitizer_disabled(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
redis_client.connection = connection

with mock.patch.object(redis_client, "connection"):
redis_client.set("key", "value")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

span = spans[0]
self.assertEqual(span.attributes.get("db.statement"), "SET key value")

def test_no_op_tracer_provider(self):
RedisInstrumentor().uninstrument()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def _check_span(self, span, name):

def test_long_command_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

self.redis_client.mget(*range(2000))

Expand All @@ -75,7 +73,7 @@ def test_long_command(self):
self._check_span(span, "MGET")
self.assertTrue(
span.attributes.get(SpanAttributes.DB_STATEMENT).startswith(
"MGET 0 1 2 3"
"MGET ? ? ? ?"
)
)
self.assertTrue(
Expand All @@ -84,9 +82,7 @@ def test_long_command(self):

def test_basics_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

self.assertIsNone(self.redis_client.get("cheese"))
spans = self.memory_exporter.get_finished_spans()
Expand All @@ -105,15 +101,13 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

def test_pipeline_traced_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

with self.redis_client.pipeline(transaction=False) as pipeline:
pipeline.set("blah", 32)
Expand Down Expand Up @@ -144,15 +138,13 @@ def test_pipeline_traced(self):
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

def test_pipeline_immediate_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

with self.redis_client.pipeline() as pipeline:
pipeline.set("a", 1)
Expand Down Expand Up @@ -182,7 +174,7 @@ def test_pipeline_immediate(self):
span = spans[0]
self._check_span(span, "SET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET b 2"
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?"
)

def test_parent(self):
Expand Down Expand Up @@ -230,7 +222,7 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

Expand All @@ -247,7 +239,7 @@ def test_pipeline_traced(self):
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

Expand Down Expand Up @@ -308,7 +300,7 @@ def test_long_command(self):
self._check_span(span, "MGET")
self.assertTrue(
span.attributes.get(SpanAttributes.DB_STATEMENT).startswith(
"MGET 0 1 2 3"
"MGET ? ? ? ?"
)
)
self.assertTrue(
Expand All @@ -322,7 +314,7 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

Expand All @@ -344,7 +336,7 @@ async def pipeline_simple():
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

Expand All @@ -364,7 +356,7 @@ async def pipeline_immediate():
span = spans[0]
self._check_span(span, "SET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET b 2"
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?"
)

def test_parent(self):
Expand Down Expand Up @@ -412,7 +404,7 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

Expand All @@ -434,7 +426,7 @@ async def pipeline_simple():
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

Expand Down Expand Up @@ -488,5 +480,5 @@ def test_get(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET foo"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)

0 comments on commit 37d85f0

Please sign in to comment.