@@ -29,8 +29,8 @@ class ContextProtocol(Protocol):
2929 Implements set and get variables in a thread safe way.
3030 """
3131
32- def __init__ (self , name , default , lock ):
33- # type: (string, Any, threading.Lock ) -> None
32+ def __init__ (self , name , default ):
33+ # type: (string, Any) -> None
3434 pass
3535
3636 def clear (self ):
@@ -54,11 +54,10 @@ class _AsyncContext(object):
5454 Uses contextvars to set and get variables globally in a thread safe way.
5555 """
5656
57- def __init__ (self , name , default , lock ):
57+ def __init__ (self , name , default ):
5858 self .name = name
5959 self .contextvar = contextvars .ContextVar (name )
6060 self .default = default if callable (default ) else (lambda : default )
61- self .lock = lock
6261
6362 def clear (self ):
6463 # type: () -> None
@@ -78,8 +77,7 @@ def get(self):
7877 def set (self , value ):
7978 # type: (Any) -> None
8079 """Set the value in the context."""
81- with self .lock :
82- self .contextvar .set (value )
80+ self .contextvar .set (value )
8381
8482
8583class _ThreadLocalContext (object ):
@@ -88,11 +86,10 @@ class _ThreadLocalContext(object):
8886 """
8987 _thread_local = threading .local ()
9088
91- def __init__ (self , name , default , lock ):
92- # type: (str, Any, threading.Lock ) -> None
89+ def __init__ (self , name , default ):
90+ # type: (str, Any) -> None
9391 self .name = name
9492 self .default = default if callable (default ) else (lambda : default )
95- self .lock = lock
9693
9794 def clear (self ):
9895 # type: () -> None
@@ -112,16 +109,14 @@ def get(self):
112109 def set (self , value ):
113110 # type: (Any) -> None
114111 """Set the value in the context."""
115- with self .lock :
116- setattr (self ._thread_local , self .name , value )
112+ setattr (self ._thread_local , self .name , value )
117113
118114
119- class TracingContext :
120- _lock = threading .Lock ()
121-
115+ class TracingContext (object ):
122116 def __init__ (self ):
123117 # type: () -> None
124- self .current_span = TracingContext ._get_context_class ("current_span" , None )
118+ context_class = _AsyncContext if contextvars else _ThreadLocalContext
119+ self .current_span = context_class ("current_span" , None )
125120
126121 def with_current_context (self , func ):
127122 # type: (Callable[[Any], Any]) -> Any
@@ -146,17 +141,4 @@ def call_with_current_context(*args, **kwargs):
146141
147142 return call_with_current_context
148143
149- @classmethod
150- def _get_context_class (cls , name , default_val ):
151- # type: (str, Any) -> ContextProtocol
152- """
153- Returns an instance of the the context class that stores the variable.
154- :param name: The key to store the variable in the context class
155- :param default_val: The default value of the variable if unset
156- :return: An instance that implements the context protocol class
157- """
158- context_class = _AsyncContext if contextvars else _ThreadLocalContext
159- return context_class (name , default_val , cls ._lock )
160-
161-
162144tracing_context = TracingContext ()
0 commit comments