2 changes: 1 addition & 1 deletion libs/langchain_v1/langchain/__init__.py
@@ -5,7 +5,7 @@
 __version__ = "1.0.0a4"


-def __getattr__(name: str) -> Any:  # noqa: ANN401
+def __getattr__(name: str) -> Any:
     """Get an attribute from the package.

     TODO: will be removed in a future alpha version.
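For context, a module-level `__getattr__` like the one touched here implements PEP 562 dynamic attribute access, which packages commonly use for lazy imports. A minimal sketch of the pattern; the `_DYNAMIC_IMPORTS` mapping and its contents are illustrative placeholders, not the actual langchain table:

```python
# PEP 562: module-level __getattr__ is called for attributes not found
# through the normal lookup, enabling lazy submodule imports.
from importlib import import_module
from typing import Any

_DYNAMIC_IMPORTS = {"create_agent": "langchain.agents"}  # hypothetical map


def __getattr__(name: str) -> Any:
    if name in _DYNAMIC_IMPORTS:
        return getattr(import_module(_DYNAMIC_IMPORTS[name]), name)
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)
```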
6 changes: 3 additions & 3 deletions libs/langchain_v1/langchain/_internal/_prompts.py
@@ -24,7 +24,7 @@


 def resolve_prompt(
-    prompt: str | None | Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
+    prompt: str | Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]] | None,
     state: StateT,
     runtime: Runtime[ContextT],
     default_user_content: str,
@@ -86,9 +86,9 @@ def custom_prompt(state, runtime):

 async def aresolve_prompt(
     prompt: str
-    | None
     | Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]]
-    | Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]],
+    | Callable[[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]]
+    | None,
     state: StateT,
     runtime: Runtime[ContextT],
     default_user_content: str,
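Both hunks above apply the same convention, apparently ruff's RUF036: in a union annotation, `None` goes last. The two spellings are equivalent at runtime; a minimal before/after sketch with a simplified signature:

```python
from collections.abc import Callable

# Flagged: `None` sits in the middle of the union.
def resolve_before(prompt: str | None | Callable[[], str]) -> str: ...

# Preferred: `None` last. Type checkers treat both identically.
def resolve_after(prompt: str | Callable[[], str] | None) -> str: ...
```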
2 changes: 1 addition & 1 deletion libs/langchain_v1/langchain/agents/interrupt.py
@@ -94,4 +94,4 @@ class HumanResponse(TypedDict):
     """

     type: Literal["accept", "ignore", "response", "edit"]
-    args: None | str | ActionRequest
+    args: str | ActionRequest | None
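The same `None`-last reordering applied to a TypedDict field. A construction sketch; the `ActionRequest` shape below is assumed for illustration, since only its name appears in this diff:

```python
from typing import Literal, TypedDict


class ActionRequest(TypedDict):
    action: str  # assumed minimal shape, not shown in the diff
    args: dict


class HumanResponse(TypedDict):
    type: Literal["accept", "ignore", "response", "edit"]
    args: str | ActionRequest | None


# An "accept" decision carries no arguments, hence args may be None.
resp: HumanResponse = {"type": "accept", "args": None}
```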
14 changes: 7 additions & 7 deletions libs/langchain_v1/langchain/agents/middleware/prompt_caching.py
@@ -16,47 +16,47 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):

     def __init__(
         self,
-        type: Literal["ephemeral"] = "ephemeral",
+        cache_type: Literal["ephemeral"] = "ephemeral",

Review comment (Collaborator, author): I guess it's still OK to make these breaking changes, since this hasn't been released yet?

         ttl: Literal["5m", "1h"] = "5m",
         min_messages_to_cache: int = 0,
     ) -> None:
         """Initialize the middleware with cache control settings.

         Args:
-            type: The type of cache to use, only "ephemeral" is supported.
+            cache_type: The type of cache to use, only "ephemeral" is supported.
             ttl: The time to live for the cache, only "5m" and "1h" are supported.
             min_messages_to_cache: The minimum number of messages until the cache is used,
                 default is 0.
         """
-        self.type = type
+        self.cache_type = cache_type
         self.ttl = ttl
         self.min_messages_to_cache = min_messages_to_cache

     def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:  # noqa: ARG002
         """Modify the model request to add cache control blocks."""
         try:
             from langchain_anthropic import ChatAnthropic
-        except ImportError:
+        except ImportError as e:
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
                 "Anthropic models."
                 "Please install langchain-anthropic."
             )
-            raise ValueError(msg)
+            raise ValueError(msg) from e

         if not isinstance(request.model, ChatAnthropic):
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
                 f"Anthropic models, not instances of {type(request.model)}"
             )
-            raise ValueError(msg)
+            raise TypeError(msg)

Review comment (Collaborator, author): I guess it's still OK to make these breaking changes, since this hasn't been released yet?

         messages_count = (
             len(request.messages) + 1 if request.system_prompt else len(request.messages)
         )
         if messages_count < self.min_messages_to_cache:
             return request

-        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
+        request.model_settings["cache_control"] = {"type": self.cache_type, "ttl": self.ttl}

         return request
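Since the keyword is renamed, call sites need updating too. A usage sketch, assuming the import path matches the file location above and that the constructor is otherwise unchanged:

```python
from langchain.agents.middleware.prompt_caching import (
    AnthropicPromptCachingMiddleware,
)

# Was `type="ephemeral"` before this PR; `cache_type` stops shadowing
# the `type` builtin inside __init__.
middleware = AnthropicPromptCachingMiddleware(
    cache_type="ephemeral",
    ttl="1h",
    min_messages_to_cache=5,
)
```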
@@ -48,7 +48,7 @@
 <messages>
 Messages to summarize:
 {messages}
-</messages>"""  # noqa: E501
+</messages>"""

 SUMMARY_PREFIX = "## Previous conversation summary:"

@@ -229,7 +229,7 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
         try:
             response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
             return cast("str", response.content).strip()
-        except Exception as e:  # noqa: BLE001
+        except Exception as e:
             return f"Error generating summary: {e!s}"

     def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
@@ -244,5 +244,5 @@ def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
                 allow_partial=True,
                 include_system=True,
             )
-        except Exception:  # noqa: BLE001
+        except Exception:
             return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
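The two removed `# noqa: BLE001` markers guarded deliberately broad `except Exception` handlers; the blanket-exception rule is evidently no longer enforced here. The surrounding pattern is graceful degradation, sketched self-contained below; the constant's value is assumed, as it isn't shown in this diff:

```python
_DEFAULT_FALLBACK_MESSAGE_COUNT = 10  # assumed value for illustration


def risky_trim(messages: list[str]) -> list[str]:
    raise RuntimeError("token counter unavailable")  # stand-in failure


def trim_or_fallback(messages: list[str]) -> list[str]:
    # Mirrors _trim_messages_for_summary: any trimming failure degrades
    # to "keep the most recent messages" instead of crashing the agent.
    try:
        return risky_trim(messages)
    except Exception:
        return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
```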
6 changes: 3 additions & 3 deletions libs/langchain_v1/langchain/agents/middleware_agent.py
@@ -108,7 +108,7 @@ def _handle_structured_output_error(
 ResponseT = TypeVar("ResponseT")


-def create_agent(  # noqa: PLR0915
+def create_agent(
     *,
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
@@ -278,13 +278,13 @@ def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
                     ],
                     "response": structured_response,
                 }
-            except Exception as exc:  # noqa: BLE001
+            except Exception as exc:
                 exception = StructuredOutputValidationError(tool_call["name"], exc)
                 should_retry, error_message = _handle_structured_output_error(
                     exception, response_format
                 )
                 if not should_retry:
-                    raise exception
+                    raise exception from exc

             return {
                 "messages": [
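`raise exception from exc` here (and in react_agent.py below) addresses ruff's B904: re-raising inside an `except` block without `from` loses the causal link. A minimal, runnable demonstration of the difference:

```python
class StructuredOutputError(Exception):
    """Stand-in for StructuredOutputValidationError."""


def parse_payload(payload: dict) -> str:
    try:
        return payload["structured"]
    except KeyError as exc:
        # `from exc` sets __cause__, so the traceback prints the original
        # KeyError under "The above exception was the direct cause ...".
        raise StructuredOutputError("missing structured output") from exc
```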
14 changes: 7 additions & 7 deletions libs/langchain_v1/langchain/agents/react_agent.py
@@ -122,9 +122,9 @@ def _get_prompt_runnable(prompt: Prompt | None) -> Runnable:
             lambda state: _get_state_value(state, "messages"), name=PROMPT_RUNNABLE_NAME
         )
     elif isinstance(prompt, str):
-        _system_message: BaseMessage = SystemMessage(content=prompt)
+        system_message: BaseMessage = SystemMessage(content=prompt)

Review comment (Collaborator): What's the motivation?

Review comment (Collaborator, author): This is ruff rule RUF052; it was done by auto-fix.

         prompt_runnable = RunnableCallable(
-            lambda state: [_system_message, *_get_state_value(state, "messages")],
+            lambda state: [system_message, *_get_state_value(state, "messages")],
             name=PROMPT_RUNNABLE_NAME,
         )
     elif isinstance(prompt, SystemMessage):
@@ -220,7 +220,7 @@ def __init__(
                 "The `model` parameter should not have pre-bound tools, "
                 "simply pass the model and tools separately."
             )
-            raise ValueError(msg)
+            raise TypeError(msg)

Review comment (Collaborator, author): I guess it's still OK to make these breaking changes, since this hasn't been released yet?

         self._setup_tools()
         self._setup_state_schema()
@@ -397,13 +397,13 @@ def _handle_single_structured_output(
                     "structured_response": structured_response,
                 }
             )
-        except Exception as exc:  # noqa: BLE001
+        except Exception as exc:
            exception = StructuredOutputValidationError(tool_call["name"], exc)

            should_retry, error_message = self._handle_structured_output_error(exception)

            if not should_retry:
-                raise exception
+                raise exception from exc

            return Command(
                update={
@@ -583,7 +583,7 @@ def _are_more_steps_needed(state: StateT, response: BaseMessage) -> bool:
         remaining_steps is not None  # type: ignore[return-value]
         and (
             (remaining_steps < 1 and all_tools_return_direct)
-            or (remaining_steps < 2 and has_tool_calls)
+            or (remaining_steps < 2 and has_tool_calls)  # noqa: PLR2004
         )
     )
@@ -1188,7 +1188,7 @@ def check_weather(location: str) -> str:
         response_format = ToolStrategy(
             schema=response_format,
         )
-    elif isinstance(response_format, tuple) and len(response_format) == 2:
+    elif isinstance(response_format, tuple) and len(response_format) == 2:  # noqa: PLR2004
         msg = "Passing a 2-tuple as response_format is no longer supported. "
         raise ValueError(msg)
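Several hunks in this PR also swap `ValueError` for `TypeError` where the complaint is about an argument's type rather than its value (ruff's TRY004 recommends this). The distinction in one sketch, using names from the middleware above:

```python
def validate_ttl(ttl: str) -> None:
    if not isinstance(ttl, str):
        # Wrong kind of object entirely -> TypeError.
        raise TypeError(f"ttl must be str, got {type(ttl).__name__}")
    if ttl not in {"5m", "1h"}:
        # Right type, unsupported value -> ValueError.
        raise ValueError(f"unsupported ttl: {ttl!r}")
```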
2 changes: 1 addition & 1 deletion libs/langchain_v1/langchain/agents/structured_output.py
@@ -217,7 +217,7 @@ def __init__(

 def _iter_variants(schema: Any) -> Iterable[Any]:
     """Yield leaf variants from Union and JSON Schema oneOf."""
-    if get_origin(schema) in (UnionType, Union):
+    if get_origin(schema) in {UnionType, Union}:
         for arg in get_args(schema):
             yield from _iter_variants(arg)
         return
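The tuple-to-set change here (and the matching ones in tool_node.py) is a membership-test idiom: for small constant collections of hashable items, `in {a, b}` signals set membership, and CPython folds such literals into frozenset constants. Behavior is unchanged, as this self-contained check shows:

```python
from types import UnionType
from typing import Union, get_origin


def is_union_annotation(schema: object) -> bool:
    # `int | str` has origin types.UnionType; Union[int, str] has origin
    # typing.Union -- the set covers both spellings.
    return get_origin(schema) in {UnionType, Union}


assert is_union_annotation(int | str)
assert is_union_annotation(Union[int, str])
assert not is_union_annotation(list[int])
```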
58 changes: 28 additions & 30 deletions libs/langchain_v1/langchain/agents/tool_node.py
@@ -130,7 +130,7 @@ def msg_content_output(output: Any) -> str | list[dict]:
     # any existing ToolNode usage.
     try:
         return json.dumps(output, ensure_ascii=False)
-    except Exception:  # noqa: BLE001
+    except Exception:
         return str(output)

@@ -207,7 +207,7 @@ def _handle_tool_error(
             f"Got unexpected type of `handle_tool_error`. Expected bool, str "
             f"or callable. Received: {flag}"
         )
-        raise ValueError(msg)
+        raise TypeError(msg)
     return content

@@ -239,7 +239,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
     params = list(sig.parameters.values())
     if params:
         # If it's a method, the first argument is typically 'self' or 'cls'
-        if params[0].name in ["self", "cls"] and len(params) == 2:
+        if params[0].name in {"self", "cls"} and len(params) == 2:  # noqa: PLR2004
             first_param = params[1]
         else:
             first_param = params[0]
@@ -378,7 +378,7 @@ def handle_errors(e: ValueError) -> str:

         tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
         ```
-    """  # noqa: E501
+    """

     name: str = "tools"

@@ -426,12 +426,12 @@ def tools_by_name(self) -> dict[str, BaseTool]:

     def _func(
         self,
-        input: list[AnyMessage] | dict[str, Any] | BaseModel,
+        input_: list[AnyMessage] | dict[str, Any] | BaseModel,
         config: RunnableConfig,
         *,
         store: Optional[BaseStore],  # noqa: UP045
     ) -> Any:
-        tool_calls, input_type = self._parse_input(input, store)
+        tool_calls, input_type = self._parse_input(input_, store)
         config_list = get_config_list(config, len(tool_calls))
         input_types = [input_type] * len(tool_calls)
         with get_executor_for_config(config) as executor:
@@ -441,12 +441,12 @@ def _func(

     async def _afunc(
         self,
-        input: list[AnyMessage] | dict[str, Any] | BaseModel,
+        input_: list[AnyMessage] | dict[str, Any] | BaseModel,
         config: RunnableConfig,
         *,
         store: Optional[BaseStore],  # noqa: UP045
     ) -> Any:
-        tool_calls, input_type = self._parse_input(input, store)
+        tool_calls, input_type = self._parse_input(input_, store)
        outputs = await asyncio.gather(
            *(self._arun_one(call, input_type, config) for call in tool_calls)
        )
@@ -630,20 +630,20 @@ async def _arun_one(

     def _parse_input(
         self,
-        input: list[AnyMessage] | dict[str, Any] | BaseModel,
+        input_: list[AnyMessage] | dict[str, Any] | BaseModel,
         store: BaseStore | None,
     ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
         input_type: Literal["list", "dict", "tool_calls"]
-        if isinstance(input, list):
-            if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call":
+        if isinstance(input_, list):
+            if isinstance(input_[-1], dict) and input_[-1].get("type") == "tool_call":
                 input_type = "tool_calls"
-                tool_calls = cast("list[ToolCall]", input)
+                tool_calls = cast("list[ToolCall]", input_)
                 return tool_calls, input_type
             input_type = "list"
-            messages = input
-        elif isinstance(input, dict) and (messages := input.get(self._messages_key, [])):
+            messages = input_
+        elif isinstance(input_, dict) and (messages := input_.get(self._messages_key, [])):
             input_type = "dict"
-        elif messages := getattr(input, self._messages_key, []):
+        elif messages := getattr(input_, self._messages_key, []):
             # Assume dataclass-like state that can coerce from dict
             input_type = "dict"
         else:
@@ -652,12 +652,12 @@ def _parse_input(

         try:
             latest_ai_message = next(m for m in reversed(messages) if isinstance(m, AIMessage))
-        except StopIteration:
+        except StopIteration as e:
             msg = "No AIMessage found in input"
-            raise ValueError(msg)
+            raise ValueError(msg) from e

         tool_calls = [
-            self.inject_tool_args(call, input, store) for call in latest_ai_message.tool_calls
+            self.inject_tool_args(call, input_, store) for call in latest_ai_message.tool_calls
         ]
         return tool_calls, input_type

@@ -677,15 +677,15 @@ def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None:
     def _inject_state(
         self,
         tool_call: ToolCall,
-        input: list[AnyMessage] | dict[str, Any] | BaseModel,
+        input_: list[AnyMessage] | dict[str, Any] | BaseModel,
     ) -> ToolCall:
         state_args = self._tool_to_state_args[tool_call["name"]]
-        if state_args and isinstance(input, list):
+        if state_args and isinstance(input_, list):
             required_fields = list(state_args.values())
             if (
                 len(required_fields) == 1 and required_fields[0] == self._messages_key
             ) or required_fields[0] is None:
-                input = {self._messages_key: input}
+                input_ = {self._messages_key: input_}
             else:
                 err_msg = (
                     f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
@@ -696,14 +696,14 @@ def _inject_state(
                 err_msg += f" State should contain fields {required_fields_str}."
                 raise ValueError(err_msg)

-        if isinstance(input, dict):
+        if isinstance(input_, dict):
             tool_state_args = {
-                tool_arg: input[state_field] if state_field else input
+                tool_arg: input_[state_field] if state_field else input_
                 for tool_arg, state_field in state_args.items()
             }
         else:
             tool_state_args = {
-                tool_arg: getattr(input, state_field) if state_field else input
+                tool_arg: getattr(input_, state_field) if state_field else input_
                 for tool_arg, state_field in state_args.items()
             }

@@ -734,7 +734,7 @@ def _inject_store(self, tool_call: ToolCall, store: BaseStore | None) -> ToolCall:
     def inject_tool_args(
         self,
         tool_call: ToolCall,
-        input: list[AnyMessage] | dict[str, Any] | BaseModel,
+        input_: list[AnyMessage] | dict[str, Any] | BaseModel,
         store: BaseStore | None,
     ) -> ToolCall:
         """Inject graph state and store into tool call arguments.
@@ -751,7 +751,7 @@ def inject_tool_args(
         Args:
             tool_call: The tool call dictionary to augment with injected arguments.
                 Must contain 'name', 'args', 'id', and 'type' fields.
-            input: The current graph state to inject into tools requiring state access.
+            input_: The current graph state to inject into tools requiring state access.
                 Can be a message list, state dictionary, or BaseModel instance.
             store: The persistent store instance to inject into tools requiring storage.
                 Will be None if no store is configured for the graph.
@@ -774,7 +774,7 @@ def inject_tool_args(
             return tool_call

         tool_call_copy: ToolCall = copy(tool_call)
-        tool_call_with_state = self._inject_state(tool_call_copy, input)
+        tool_call_with_state = self._inject_state(tool_call_copy, input_)
         return self._inject_store(tool_call_with_state, store)

     def _validate_tool_command(
@@ -786,7 +786,7 @@ def _validate_tool_command(
         if isinstance(command.update, dict):
             # input type is dict when ToolNode is invoked with a dict input
             # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
-            if input_type not in ("dict", "tool_calls"):
+            if input_type not in {"dict", "tool_calls"}:
                 msg = (
                     "Tools can provide a dict in Command.update only when using dict "
                     f"with '{self._messages_key}' key as ToolNode input, "
@@ -1141,8 +1141,6 @@ def _get_state_args(tool: BaseTool) -> dict[str, str | None]:
                 tool_args_to_state_fields[name] = injection.field
             else:
                 tool_args_to_state_fields[name] = None
-        else:
-            pass
     return tool_args_to_state_fields
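The bulk of this file is the `input` to `input_` rename, which stops the parameter from shadowing the `input` builtin (ruff's A002); the trailing underscore is the PEP 8 convention for such collisions. Note that for `inject_tool_args`, a documented public method, this is also a breaking rename for anyone passing `input=` by keyword. A tiny illustration of why the rename matters:

```python
def count_messages(input_: list[dict]) -> int:
    # With the rename, `input` inside the body still refers to the
    # builtin, so nothing in this scope shadows it.
    return sum(1 for item in input_ if item.get("type") == "message")
```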