Merged

Commits (27)
3096f9f  adding functionality to require tool calling (adrianlyjak, May 7, 2025)
a480059  add anthropic (adrianlyjak, May 14, 2025)
b4a4e75  add azure fix (adrianlyjak, May 14, 2025)
c3178f5  update bedrock (adrianlyjak, May 14, 2025)
9723fb9  add cohere tests (adrianlyjak, May 14, 2025)
1dd85e7  gemini tests (adrianlyjak, May 15, 2025)
3b0c3ea  test huggingface api (adrianlyjak, May 15, 2025)
53babee  test ibm watsonx (adrianlyjak, May 15, 2025)
c5399a6  test litellm (adrianlyjak, May 15, 2025)
c603f95  add mistral tests (adrianlyjak, May 15, 2025)
720be16  add oci data science tests (adrianlyjak, May 23, 2025)
381c492  add tests for oci pt 2 (adrianlyjak, May 23, 2025)
edb097f  test/fix openai (adrianlyjak, May 23, 2025)
f6a88cd  implement tool_config for vertex ai (adrianlyjak, May 23, 2025)
63dac3c  test vertex (adrianlyjak, May 23, 2025)
aea02db  clarify ollama (adrianlyjak, May 23, 2025)
a9181c3  fix gemini imports (adrianlyjak, May 23, 2025)
a1297e8  fix formats (adrianlyjak, May 27, 2025)
9cec0ae  fix bedrock test (adrianlyjak, May 28, 2025)
c9f1851  Fix issues in gemini structured predict, and neaten up the tests (adrianlyjak, May 28, 2025)
07567d7  integration test mistral (adrianlyjak, May 28, 2025)
7740da9  integration test anthropic (adrianlyjak, May 28, 2025)
c3e3b52  integration test cohere (adrianlyjak, May 28, 2025)
d64f9a6  Add openai integration tests (adrianlyjak, May 28, 2025)
3ef48a4  missed google genai tool_required; implement it and add unit/inte… (adrianlyjak, May 28, 2025)
f9b4bad  Update the one reference to tool_choice (adrianlyjak, May 28, 2025)
ca3770d  version bump all of the affected llms (adrianlyjak, May 29, 2025)
2 changes: 1 addition & 1 deletion docs/docs/understanding/extraction/lower_level.md
@@ -13,7 +13,7 @@ resp = llm.chat_with_tools(
[tool],
# chat_history=chat_history, # can optionally pass in chat history instead of user_msg
user_msg="Extract an invoice from the following text: " + text,
# tool_choice="Invoice", # can optionally force the tool call
tool_required=True, # can optionally force the tool call
)

tool_calls = llm.get_tool_calls_from_response(
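For context, here is a minimal end-to-end sketch of the documented flow with the new flag. The `Invoice` schema, the model name, and the input text are illustrative assumptions, not part of the docs page:

```python
from pydantic import BaseModel

from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


class Invoice(BaseModel):
    """Hypothetical extraction target; the docs page defines its own schema."""

    invoice_number: str
    total: float


def record_invoice(invoice_number: str, total: float) -> Invoice:
    """Record an invoice extracted from text."""
    return Invoice(invoice_number=invoice_number, total=total)


llm = OpenAI(model="gpt-4o-mini")  # any FunctionCallingLLM should work here
tool = FunctionTool.from_defaults(fn=record_invoice, name="Invoice")

text = "Invoice #1234, total due: $99.50"  # stand-in input
resp = llm.chat_with_tools(
    [tool],
    user_msg="Extract an invoice from the following text: " + text,
    tool_required=True,  # the model must call a tool instead of answering in prose
)
tool_calls = llm.get_tool_calls_from_response(resp, error_on_no_tool_call=True)
print([tc.tool_kwargs for tc in tool_calls])
```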
9 changes: 9 additions & 0 deletions llama-index-core/llama_index/core/llms/function_calling.py
@@ -33,6 +33,7 @@ def chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False, # if required, LLM should only call tools, and not return a response
**kwargs: Any,
) -> ChatResponse:
"""Chat with function calling."""
@@ -42,6 +43,7 @@
chat_history=chat_history,
verbose=verbose,
allow_parallel_tool_calls=allow_parallel_tool_calls,
tool_required=tool_required,
**kwargs,
)
response = self.chat(**chat_kwargs)
@@ -59,6 +61,7 @@ async def achat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponse:
"""Async chat with function calling."""
@@ -68,6 +71,7 @@
chat_history=chat_history,
verbose=verbose,
allow_parallel_tool_calls=allow_parallel_tool_calls,
tool_required=tool_required,
**kwargs,
)
response = await self.achat(**chat_kwargs)
@@ -85,6 +89,7 @@ def stream_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponseGen:
"""Stream chat with function calling."""
@@ -94,6 +99,7 @@
chat_history=chat_history,
verbose=verbose,
allow_parallel_tool_calls=allow_parallel_tool_calls,
tool_required=tool_required,
**kwargs,
)
# TODO: no validation for streaming outputs
@@ -106,6 +112,7 @@ async def astream_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponseAsyncGen:
"""Async stream chat with function calling."""
@@ -115,6 +122,7 @@
chat_history=chat_history,
verbose=verbose,
allow_parallel_tool_calls=allow_parallel_tool_calls,
tool_required=tool_required,
**kwargs,
)
# TODO: no validation for streaming outputs
@@ -128,6 +136,7 @@ def _prepare_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False, # if required, LLM should only call tools, and not return a response
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the arguments needed to let the LLM chat with tools."""
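Since `_prepare_chat_with_tools` is the single override point, each integration only has to translate the new boolean into its provider's dialect. A minimal sketch for a hypothetical OpenAI-style provider follows; the class name and the `"required"`/`"auto"` mapping are illustrative assumptions, not an actual integration in this PR:

```python
from typing import Any, Dict, List, Optional, Union

from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.tools.types import BaseTool


class MyProviderLLM(FunctionCallingLLM):
    # Abstract members (chat, metadata, etc.) omitted; only the
    # tool_required plumbing is shown.

    def _prepare_chat_with_tools(
        self,
        tools: List[BaseTool],
        user_msg: Optional[Union[str, ChatMessage]] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        verbose: bool = False,
        allow_parallel_tool_calls: bool = False,
        tool_required: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
        return {
            "messages": chat_history or [],
            "tools": tool_specs or None,
            # OpenAI-style dialect: "required" forces some tool call,
            # "auto" lets the model decide whether to call a tool at all.
            "tool_choice": "required" if tool_required else "auto",
            **kwargs,
        }
```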
3 changes: 3 additions & 0 deletions llama-index-core/tests/agent/function_calling/test_step.py
@@ -97,6 +97,7 @@ def _prepare_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare chat with tools."""
@@ -109,6 +110,7 @@ def chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponse:
return ChatResponse(message=ChatMessage(role="user", content=""))
@@ -120,6 +122,7 @@ async def achat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponse:
return ChatResponse(message=ChatMessage(role="user", content=""))
1 change: 1 addition & 0 deletions llama-index-core/tests/llms/test_function_calling.py
@@ -69,6 +69,7 @@ def _prepare_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
return {"messages": []}
@@ -216,6 +216,7 @@ def _prepare_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False, # ai21 does not support configuring the tool_choice

[Review thread on this line]

logan-markewich (Collaborator): Should we raise an error if it's set to True?

adrianlyjak (Contributor, author): @logan-markewich My thinking is that this will frequently get set more by internal library code than by user code. For example, we'll want to migrate StructuredLLM to set tool_required=True (it currently sets tool_choice="function_name" or similar and hopes for the best from the LLM implementation). It seemed better to have it maybe give a tool response rather than blow up, like it currently does.

adrianlyjak (Contributor, author): Otherwise, the alternative is to somehow advertise whether each LLM supports tool_required or not, and to check that before providing it, which seems like a lot of gymnastics for mostly just a few underused LLMs.

logan-markewich (Collaborator): That's fair!

**kwargs: Any,
) -> Dict[str, Any]:
tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
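Per the thread above, integrations that cannot express the constraint accept the flag for interface compatibility and simply do not forward it. A condensed sketch of that pattern (signature abbreviated for illustration):

```python
def _prepare_chat_with_tools(self, tools, tool_required=False, **kwargs):
    # ai21 exposes no tool_choice equivalent, so tool_required is accepted
    # (keeping the shared FunctionCallingLLM signature uniform) but is
    # deliberately not forwarded; the model may still answer without a tool.
    tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
    return {"tools": tool_specs, **kwargs}
```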
@@ -27,7 +27,7 @@ dev = [

[project]
name = "llama-index-llms-ai21"
version = "0.4.0"
version = "0.5.0"
description = "llama-index llms ai21 integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"
@@ -641,13 +641,22 @@ async def gen() -> AsyncGenerator[AnthropicCompletionResponse, None]:

return gen()

def _map_tool_choice_to_anthropic(
self, tool_required: bool, allow_parallel_tool_calls: bool
) -> dict:
return {
"disable_parallel_tool_use": not allow_parallel_tool_calls,
"type": "any" if tool_required else "auto",
}

def _prepare_chat_with_tools(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the chat with tools."""
@@ -672,7 +681,14 @@ def _prepare_chat_with_tools(
):
tool_dicts[-1]["cache_control"] = {"type": "ephemeral"}

return {"messages": chat_history, "tools": tool_dicts, **kwargs}
return {
"messages": chat_history,
"tools": tool_dicts,
"tool_choice": self._map_tool_choice_to_anthropic(
tool_required, allow_parallel_tool_calls
),
**kwargs,
}

def _validate_chat_with_tools_response(
self,
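The helper folds both booleans into the single `tool_choice` object that Anthropic's Messages API expects. A quick sketch of the resulting payloads, using the mapping exactly as added above (outputs shown as comments):

```python
from llama_index.llms.anthropic import Anthropic

llm = Anthropic()  # no API call is made; this only exercises the mapping

# Force a tool call, one tool call at a time:
print(llm._map_tool_choice_to_anthropic(tool_required=True, allow_parallel_tool_calls=False))
# {'disable_parallel_tool_use': True, 'type': 'any'}

# Let the model decide, parallel calls permitted:
print(llm._map_tool_choice_to_anthropic(tool_required=False, allow_parallel_tool_calls=True))
# {'disable_parallel_tool_use': False, 'type': 'auto'}
```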
@@ -27,7 +27,7 @@ dev = [

[project]
name = "llama-index-llms-anthropic"
version = "0.6.19"
version = "0.7.0"
description = "llama-index llms anthropic integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"
@@ -12,7 +12,9 @@
MessageRole,
ChatResponse,
)
from llama_index.core.tools import FunctionTool
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.anthropic.base import AnthropicChatResponse


def test_text_inference_embedding_class():
@@ -225,6 +227,36 @@ def pdf_url() -> str:
return "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"


@pytest.mark.skipif(
os.getenv("ANTHROPIC_API_KEY") is None,
reason="Anthropic API key not available to test Anthropic integration",
)
def test_tool_required():
llm = Anthropic(model="claude-3-5-sonnet-latest")

search_tool = FunctionTool.from_defaults(fn=search)

# Test with tool_required=True
response = llm.chat_with_tools(
user_msg="What is the weather in Paris?",
tools=[search_tool],
tool_required=True,
)
assert isinstance(response, AnthropicChatResponse)
assert response.message.additional_kwargs["tool_calls"] is not None
assert len(response.message.additional_kwargs["tool_calls"]) > 0

# Test with tool_required=False
response = llm.chat_with_tools(
user_msg="Say hello!",
tools=[search_tool],
tool_required=False,
)
assert isinstance(response, AnthropicChatResponse)
# Should not use tools for a simple greeting
assert not response.message.additional_kwargs.get("tool_calls")


@pytest.mark.skipif(
os.getenv("ANTHROPIC_API_KEY") is None,
reason="Anthropic API key not available to test Anthropic document uploading ",
@@ -244,3 +276,65 @@ def test_document_upload(tmp_path: Path, pdf_url: str) -> None:
messages = [msg]
response = llm.chat(messages)
assert isinstance(response, ChatResponse)


def test_map_tool_choice_to_anthropic():
"""Test that tool_required is correctly mapped to Anthropic's tool_choice parameter."""
llm = Anthropic()

# Test with tool_required=True
tool_choice = llm._map_tool_choice_to_anthropic(
tool_required=True, allow_parallel_tool_calls=False
)
assert tool_choice["type"] == "any"
assert tool_choice["disable_parallel_tool_use"]

# Test with tool_required=False
tool_choice = llm._map_tool_choice_to_anthropic(
tool_required=False, allow_parallel_tool_calls=False
)
assert tool_choice["type"] == "auto"
assert tool_choice["disable_parallel_tool_use"]

# Test with allow_parallel_tool_calls=True
tool_choice = llm._map_tool_choice_to_anthropic(
tool_required=True, allow_parallel_tool_calls=True
)
assert tool_choice["type"] == "any"
assert not tool_choice["disable_parallel_tool_use"]


def search(query: str) -> str:
"""Search for information about a query."""
return f"Results for {query}"


search_tool = FunctionTool.from_defaults(
fn=search, name="search_tool", description="A tool for searching information"
)


def test_prepare_chat_with_tools_tool_required():
"""Test that tool_required is correctly passed to the API request when True."""
llm = Anthropic()

# Test with tool_required=True
result = llm._prepare_chat_with_tools(tools=[search_tool], tool_required=True)

assert result["tool_choice"]["type"] == "any"
assert len(result["tools"]) == 1
assert result["tools"][0]["name"] == "search_tool"


def test_prepare_chat_with_tools_tool_not_required():
"""Test that tool_required is correctly passed to the API request when False."""
llm = Anthropic()

# Test with tool_required=False (default)
result = llm._prepare_chat_with_tools(
tools=[search_tool],
)

assert result["tool_choice"]["type"] == "auto"
assert len(result["tools"]) == 1
assert result["tools"][0]["name"] == "search_tool"
@@ -45,12 +45,17 @@

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.aio import ChatCompletionsClient as ChatCompletionsClientAsync
from azure.ai.inference.models import (
ChatCompletionsToolChoicePreset,
ChatCompletionsNamedToolChoice,
)

if TYPE_CHECKING:
from llama_index.core.tools.types import BaseTool
from llama_index.core.chat_engine.types import AgentChatResponse
from azure.core.credentials import TokenCredential


from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from azure.ai.inference.models import (
@@ -357,6 +362,16 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
raw=response.as_dict(),
)

def _to_azure_tool_choice(
self, tool_required: bool
) -> Optional[
Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]
]:
if tool_required:
return ChatCompletionsToolChoicePreset.REQUIRED
else:
return ChatCompletionsToolChoicePreset.AUTO

@llm_completion_callback()
def complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -463,6 +478,7 @@ def chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponse:
"""Predict and call the tool."""
@@ -481,6 +497,7 @@
response = self.chat(
messages,
tools=tool_specs,
tool_choice=self._to_azure_tool_choice(tool_required),
**kwargs,
)
if not allow_parallel_tool_calls:
@@ -494,6 +511,7 @@ async def achat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> ChatResponse:
"""Predict and call the tool."""
@@ -512,6 +530,7 @@
response = await self.achat(
messages,
tools=tool_specs,
tool_choice=self._to_azure_tool_choice(tool_required),
**kwargs,
)
if not allow_parallel_tool_calls:
@@ -561,6 +580,7 @@ def _prepare_chat_with_tools(
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
tool_required: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the arguments needed to let the LLM chat with tools."""
@@ -575,5 +595,6 @@
return {
"messages": chat_history,
"tools": tool_dicts or None,
"tool_choice": self._to_azure_tool_choice(tool_required),
**kwargs,
}
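Azure AI Inference expresses the same idea through `ChatCompletionsToolChoicePreset`. A small sketch of what the new `_to_azure_tool_choice` helper resolves to; the string comparisons assume the presets are string-valued enums, as in current `azure-ai-inference` releases:

```python
from azure.ai.inference.models import ChatCompletionsToolChoicePreset


def to_azure_tool_choice(tool_required: bool) -> ChatCompletionsToolChoicePreset:
    # Mirrors _to_azure_tool_choice above: REQUIRED forces some tool call,
    # AUTO lets the model decide.
    return (
        ChatCompletionsToolChoicePreset.REQUIRED
        if tool_required
        else ChatCompletionsToolChoicePreset.AUTO
    )


assert to_azure_tool_choice(True) == "required"
assert to_azure_tool_choice(False) == "auto"
```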
@@ -26,7 +26,7 @@ dev = [

[project]
name = "llama-index-llms-azure-inference"
version = "0.3.0"
version = "0.4.0"
description = "Integration for model supporting Azure AI model inference API in llama-index"
authors = [{name = "Azure AI model inference group", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"