diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 821939d1a609c..cabe04b6a94d1 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -2379,6 +2379,7 @@ def _configure( if run_tree is not None else tracing_context["client"] ), + tags=tracing_tags, ) callback_manager.add_handler(handler, True) except Exception as e: diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index 3694db7188b06..d36adc9bc8eaf 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -13,8 +13,8 @@ ) from uuid import UUID +from langsmith import run_helpers as ls_rh from langsmith import utils as ls_utils -from langsmith.run_helpers import get_run_tree_context from langchain_core.tracers.langchain import LangChainTracer from langchain_core.tracers.run_collector import RunCollectorCallbackHandler @@ -149,7 +149,10 @@ def _tracing_v2_is_enabled() -> Union[bool, Literal["local"]]: def _get_tracer_project() -> str: - run_tree = get_run_tree_context() + tracing_context = ls_rh.get_tracing_context() + run_tree = tracing_context["parent"] + if run_tree is None and tracing_context["project_name"] is not None: + return tracing_context["project_name"] return getattr( run_tree, "session_name", diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index af6a5e84ce9ee..0743929f86120 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -35,6 +35,25 @@ def _get_posts(client: Client) -> list: return posts +def test_tracing_context() -> None: + mock_session = MagicMock() + mock_client_ = Client( + session=mock_session, api_key="test", auto_batch_tracing=False + ) + + @RunnableLambda + def my_function(a: int) -> int: + return a + 1 + + name = uuid.uuid4().hex + project_name = f"Some project {name}" + with tracing_context(project_name=project_name, client=mock_client_, enabled=True): + assert my_function.invoke(1) == 2 + posts = _get_posts(mock_client_) + assert posts + assert all(post["session_name"] == project_name for post in posts) + + def test_config_traceable_handoff() -> None: get_env_var.cache_clear() mock_session = MagicMock()