Skip to content

Commit

Permalink
implements context propagation for lambda invoke + tests (#458)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuba-wu authored May 14, 2021
1 parent c8103f5 commit c8ec25a
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.2.0-0.21b0...HEAD)

### Added
- `opentelemetry-instrumentation-botocore` now supports
context propagation for lambda invoke via Payload embedded headers.
([#458](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/458))

## [0.21b0](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.2.0-0.21b0) - 2021-05-11
### Changed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ install_requires =
[options.extras_require]
test =
boto~=2.0
moto~=1.0
moto~=2.0
opentelemetry-test == 0.22.dev0

[options.packages.find]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ install_requires =

[options.extras_require]
test =
moto ~= 1.0
moto[all] ~= 2.0
opentelemetry-test == 0.22.dev0

[options.packages.find]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
---
"""

import json
import logging

from botocore.client import BaseClient
Expand Down Expand Up @@ -99,6 +100,27 @@ def _instrument(self, **kwargs):
def _uninstrument(self, **kwargs):
unwrap(BaseClient, "_make_api_call")

@staticmethod
def _is_lambda_invoke(service_name, operation_name, api_params):
return (
service_name == "lambda"
and operation_name == "Invoke"
and isinstance(api_params, dict)
and "Payload" in api_params
)

@staticmethod
def _patch_lambda_invoke(api_params):
try:
payload_str = api_params["Payload"]
payload = json.loads(payload_str)
headers = payload.get("headers", {})
inject(headers)
payload["headers"] = headers
api_params["Payload"] = json.dumps(payload)
except ValueError:
pass

# pylint: disable=too-many-branches
def _patched_api_call(self, original_func, instance, args, kwargs):
if context_api.get_value("suppress_instrumentation"):
Expand All @@ -111,6 +133,12 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
error = None
result = None

# inject trace context into payload headers for lambda Invoke
if BotocoreInstrumentor._is_lambda_invoke(
service_name, operation_name, api_params
):
BotocoreInstrumentor._patch_lambda_invoke(api_params)

with self._tracer.start_as_current_span(
"{}".format(service_name), kind=SpanKind.CLIENT,
) as span:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import zipfile
from unittest.mock import Mock, patch

import botocore.session
from botocore.exceptions import ParamValidationError
from moto import ( # pylint: disable=import-error
mock_dynamodb2,
mock_ec2,
mock_iam,
mock_kinesis,
mock_kms,
mock_lambda,
Expand All @@ -37,6 +40,24 @@
from opentelemetry.test.test_base import TestBase


def get_as_zip_file(file_name, content):
zip_output = io.BytesIO()
with zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) as zip_file:
zip_file.writestr(file_name, content)
zip_output.seek(0)
return zip_output.read()


def return_headers_lambda_str():
pfunc = """
def lambda_handler(event, context):
print("custom log event")
headers = event.get('headers', event.get('attributes', {}))
return headers
"""
return pfunc


class TestBotocoreInstrumentor(TestBase):
"""Botocore integration testsuite"""

Expand Down Expand Up @@ -328,6 +349,64 @@ def test_lambda_client(self):
},
)

@mock_iam
def get_role_name(self):
iam = self.session.create_client("iam", "us-east-1")
return iam.create_role(
RoleName="my-role",
AssumeRolePolicyDocument="some policy",
Path="/my-path/",
)["Role"]["Arn"]

@mock_lambda
def test_lambda_invoke_propagation(self):

previous_propagator = get_global_textmap()
try:
set_global_textmap(MockTextMapPropagator())

lamb = self.session.create_client(
"lambda", region_name="us-east-1"
)
lamb.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=self.get_role_name(),
Handler="lambda_function.lambda_handler",
Code={
"ZipFile": get_as_zip_file(
"lambda_function.py", return_headers_lambda_str()
)
},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
response = lamb.invoke(
Payload=json.dumps({}),
FunctionName="testFunction",
InvocationType="RequestResponse",
)

spans = self.memory_exporter.get_finished_spans()
assert spans
self.assertEqual(len(spans), 3)

results = response["Payload"].read().decode("utf-8")
headers = json.loads(results)

self.assertIn(MockTextMapPropagator.TRACE_ID_KEY, headers)
self.assertEqual(
"0", headers[MockTextMapPropagator.TRACE_ID_KEY],
)
self.assertIn(MockTextMapPropagator.SPAN_ID_KEY, headers)
self.assertEqual(
"0", headers[MockTextMapPropagator.SPAN_ID_KEY],
)
finally:
set_global_textmap(previous_propagator)

@mock_kms
def test_kms_client(self):
kms = self.session.create_client("kms", region_name="us-east-1")
Expand Down

0 comments on commit c8ec25a

Please sign in to comment.