Skip to content

Commit f583901

Browse files
authored
we dont need thread locks (#6551)
1 parent f9f2db1 commit f583901

File tree

2 files changed

+13
-36
lines changed

2 files changed

+13
-36
lines changed

sdk/core/azure-core/azure/core/tracing/context.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ class ContextProtocol(Protocol):
2929
Implements set and get variables in a thread safe way.
3030
"""
3131

32-
def __init__(self, name, default, lock):
33-
# type: (string, Any, threading.Lock) -> None
32+
def __init__(self, name, default):
33+
# type: (string, Any) -> None
3434
pass
3535

3636
def clear(self):
@@ -54,11 +54,10 @@ class _AsyncContext(object):
5454
Uses contextvars to set and get variables globally in a thread safe way.
5555
"""
5656

57-
def __init__(self, name, default, lock):
57+
def __init__(self, name, default):
5858
self.name = name
5959
self.contextvar = contextvars.ContextVar(name)
6060
self.default = default if callable(default) else (lambda: default)
61-
self.lock = lock
6261

6362
def clear(self):
6463
# type: () -> None
@@ -78,8 +77,7 @@ def get(self):
7877
def set(self, value):
7978
# type: (Any) -> None
8079
"""Set the value in the context."""
81-
with self.lock:
82-
self.contextvar.set(value)
80+
self.contextvar.set(value)
8381

8482

8583
class _ThreadLocalContext(object):
@@ -88,11 +86,10 @@ class _ThreadLocalContext(object):
8886
"""
8987
_thread_local = threading.local()
9088

91-
def __init__(self, name, default, lock):
92-
# type: (str, Any, threading.Lock) -> None
89+
def __init__(self, name, default):
90+
# type: (str, Any) -> None
9391
self.name = name
9492
self.default = default if callable(default) else (lambda: default)
95-
self.lock = lock
9693

9794
def clear(self):
9895
# type: () -> None
@@ -112,16 +109,14 @@ def get(self):
112109
def set(self, value):
113110
# type: (Any) -> None
114111
"""Set the value in the context."""
115-
with self.lock:
116-
setattr(self._thread_local, self.name, value)
112+
setattr(self._thread_local, self.name, value)
117113

118114

119-
class TracingContext:
120-
_lock = threading.Lock()
121-
115+
class TracingContext(object):
122116
def __init__(self):
123117
# type: () -> None
124-
self.current_span = TracingContext._get_context_class("current_span", None)
118+
context_class = _AsyncContext if contextvars else _ThreadLocalContext
119+
self.current_span = context_class("current_span", None)
125120

126121
def with_current_context(self, func):
127122
# type: (Callable[[Any], Any]) -> Any
@@ -146,17 +141,4 @@ def call_with_current_context(*args, **kwargs):
146141

147142
return call_with_current_context
148143

149-
@classmethod
150-
def _get_context_class(cls, name, default_val):
151-
# type: (str, Any) -> ContextProtocol
152-
"""
153-
Returns an instance of the the context class that stores the variable.
154-
:param name: The key to store the variable in the context class
155-
:param default_val: The default value of the variable if unset
156-
:return: An instance that implements the context protocol class
157-
"""
158-
context_class = _AsyncContext if contextvars else _ThreadLocalContext
159-
return context_class(name, default_val, cls._lock)
160-
161-
162144
tracing_context = TracingContext()

sdk/core/azure-core/tests/test_tracing_context.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
3434

3535

3636
class TestContext(unittest.TestCase):
37-
def test_get_context_class(self):
38-
with ContextHelper():
39-
slot = tracing_context._get_context_class("temp", 1)
40-
assert slot.get() == 1
41-
slot.set(2)
42-
assert slot.get() == 2
43-
4437
def test_current_span(self):
4538
with ContextHelper():
46-
assert tracing_context.current_span.get() is None
39+
assert not tracing_context.current_span.get()
4740
val = mock.Mock(spec=AbstractSpan)
4841
tracing_context.current_span.set(val)
4942
assert tracing_context.current_span.get() == val
43+
tracing_context.current_span.clear()
44+
assert not tracing_context.current_span.get()
5045

5146
def test_with_current_context(self):
5247
with ContextHelper(tracer_to_use=mock.Mock(AbstractSpan)):

0 commit comments

Comments
 (0)