diff --git a/CHANGELOG.md b/CHANGELOG.md index f779f3f109..f60d4041f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#1507](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1507)) - Fix pymongo to collect the property DB_MONGODB_COLLECTION ([#1555](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1555)) +- `opentelemetry-instrumentation-asgi` Fix keys() in class ASGIGetter to correctly fetch values from carrier headers. + ([#1435](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1435)) + ## Version 1.15.0/0.36b0 (2022-12-10) diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 55bb418647..083fe771d8 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -260,7 +260,8 @@ def get( return decoded def keys(self, carrier: dict) -> typing.List[str]: - return [_key.decode("utf8") for (_key, _value) in carrier] + headers = carrier.get("headers") or [] + return [_key.decode("utf8") for (_key, _value) in headers] asgi_getter = ASGIGetter() diff --git a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_getter.py b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_getter.py index 454162d715..26bb652b50 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_getter.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_getter.py @@ -18,12 +18,18 @@ class TestASGIGetter(TestCase): - def test_get_none(self): + def test_get_none_empty_carrier(self): getter = ASGIGetter() carrier = {} val = getter.get(carrier, "test") self.assertIsNone(val) + def test_get_none_empty_headers(self): + getter = ASGIGetter() + carrier = {"headers": []} + val = getter.get(carrier, "test") + self.assertIsNone(val) + def test_get_(self): getter = ASGIGetter() carrier = {"headers": [(b"test-key", b"val")]} @@ -44,7 +50,22 @@ def test_get_(self): "Should be case insensitive", ) - def test_keys(self): + def test_keys_empty_carrier(self): getter = ASGIGetter() keys = getter.keys({}) self.assertEqual(keys, []) + + def test_keys_empty_headers(self): + getter = ASGIGetter() + keys = getter.keys({"headers": []}) + self.assertEqual(keys, []) + + def test_keys(self): + getter = ASGIGetter() + carrier = {"headers": [(b"test-key", b"val")]} + expected_val = ["test-key"] + self.assertEqual( + getter.keys(carrier), + expected_val, + "Should be equal", + )