From 8e47c02fa25583a72c33f2058541d0fff98cbe64 Mon Sep 17 00:00:00 2001 From: juchengquan Date: Thu, 10 Jul 2025 17:43:30 +0800 Subject: [PATCH] Fix ReadError during concurrent async requests and clean up related modules Bug fixes: - Replaced `self._client_creator` with `self._client` to avoid a crash in `asyncio.gather` when proxies are set up via `proxy_mounts` Code refactoring: - Removed unused imports and added explicit `__all__` module exports. - Improved exception handling by catching `ImportError` for the optional cohere client in `hybrid_rag.py`. - Standardized string quotes and formatting for better code readability and consistency. - Updated Tavily client constructor and method parameters to use `Optional` types for better type hinting. - Cleaned up code formatting and removed unnecessary comments for clarity. --- setup.py | 28 +- tavily/__init__.py | 9 +- tavily/async_tavily.py | 646 ++++++++++++++++---------------- tavily/config.py | 26 +- tavily/errors.py | 2 - tavily/hybrid_rag/__init__.py | 6 +- tavily/hybrid_rag/hybrid_rag.py | 238 +++++++----- tavily/tavily.py | 601 +++++++++++++++-------------- tavily/utils.py | 8 +- 9 files changed, 818 insertions(+), 746 deletions(-) diff --git a/setup.py b/setup.py index b0e105e..4153897 100644 --- a/setup.py +++ b/setup.py @@ -1,23 +1,23 @@ from setuptools import setup, find_packages -with open('README.md', 'r', encoding='utf-8') as f: +with open("README.md", "r", encoding="utf-8") as f: long_description = f.read() setup( - name='tavily-python', - version='0.7.9', - url='https://github.com/tavily-ai/tavily-python', - author='Tavily AI', - author_email='support@tavily.com', - description='Python wrapper for the Tavily API', + name="tavily-python", + version="0.7.9", + url="https://github.com/tavily-ai/tavily-python", + author="Tavily AI", + author_email="support@tavily.com", + description="Python wrapper for the Tavily API", long_description=long_description, - long_description_content_type='text/markdown', - packages=find_packages(exclude=['tests']), - install_requires=['requests', 'tiktoken>=0.5.1', 'httpx'], + long_description_content_type="text/markdown", + packages=find_packages(exclude=["tests"]), + install_requires=["requests", "tiktoken>=0.5.1", "httpx"], classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires=">=3.6", ) diff --git a/tavily/__init__.py b/tavily/__init__.py index 4a2ea54..251371d 100644 --- a/tavily/__init__.py +++ b/tavily/__init__.py @@ -1,4 +1,11 @@ from .async_tavily import AsyncTavilyClient from .tavily import Client, TavilyClient from .errors import InvalidAPIKeyError, UsageLimitExceededError, MissingAPIKeyError, BadRequestError -from .hybrid_rag import TavilyHybridClient \ No newline at end of file +from .hybrid_rag import TavilyHybridClient + +__all__ = [ + "AsyncTavilyClient", + "Client", "TavilyClient", + "InvalidAPIKeyError", "UsageLimitExceededError", "MissingAPIKeyError", "BadRequestError", + "TavilyHybridClient" +] \ No newline at end of file diff --git a/tavily/async_tavily.py b/tavily/async_tavily.py index 0013aef..284ad46 100644 --- a/tavily/async_tavily.py +++ b/tavily/async_tavily.py @@ -1,7 +1,7 @@ import asyncio import json import os -from typing import Literal, Sequence, Optional, List, Union +from typing import Literal, Sequence, Optional, List, Union, cast import httpx @@ -15,10 +15,14 
@@ class AsyncTavilyClient: Async Tavily API client class. """ - def __init__(self, api_key: Optional[str] = None, - company_info_tags: Sequence[str] = ("news", "general", "finance"), - proxies: Optional[dict[str, str]] = None, - api_base_url: Optional[str] = None): + def __init__( + self, + api_key: Optional[str] = None, + company_info_tags: Sequence[str] = ("news", "general", "finance"), + proxies: Optional[dict[str, str]] = None, + api_base_url: Optional[str] = None, + verify: bool = True, + ): if api_key is None: api_key = os.getenv("TAVILY_API_KEY") @@ -34,42 +38,36 @@ def __init__(self, api_key: Optional[str] = None, mapped_proxies = {key: value for key, value in mapped_proxies.items() if value} - proxy_mounts = ( - {scheme: httpx.AsyncHTTPTransport(proxy=proxy) for scheme, proxy in mapped_proxies.items()} - if mapped_proxies - else None - ) - self._api_base_url = api_base_url or "https://api.tavily.com" self._client_creator = lambda: httpx.AsyncClient( - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - "X-Client-Source": "tavily-python" - }, + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}", "X-Client-Source": "tavily-python"}, base_url=self._api_base_url, - mounts=proxy_mounts + mounts=( + {scheme: httpx.AsyncHTTPTransport(proxy=proxy, verify=verify) for scheme, proxy in mapped_proxies.items()} + if mapped_proxies + else None + ), ) self._company_info_tags = company_info_tags async def _search( - self, - query: str, - search_depth: Literal["basic", "advanced"] = None, - topic: Literal["general", "news", "finance"] = None, - time_range: Literal["day", "week", "month", "year"] = None, - days: int = None, - max_results: int = None, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - include_answer: Union[bool, Literal["basic", "advanced"]] = None, - include_raw_content: Union[bool, Literal["markdown", "text"]] = None, - include_images: bool = None, - timeout: int = 60, - country: str = None, - auto_parameters: bool = None, - include_favicon: bool = None, - **kwargs, + self, + query: str, + search_depth: Optional[Literal["basic", "advanced"]] = None, + topic: Optional[Literal["general", "news", "finance"]] = None, + time_range: Optional[Literal["day", "week", "month", "year"]] = None, + days: Optional[int] = None, + max_results: Optional[int] = None, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + include_answer: Optional[Union[bool, Literal["basic", "advanced"]]] = None, + include_raw_content: Optional[Union[bool, Literal["markdown", "text"]]] = None, + include_images: Optional[bool] = None, + timeout: int = 60, + country: Optional[str] = None, + auto_parameters: Optional[bool] = None, + include_favicon: Optional[bool] = None, + **kwargs, ) -> dict: """ Internal search method to send the request to the API. 
@@ -115,54 +113,56 @@ async def _search( if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: - raise response.raise_for_status() - - async def search(self, - query: str, - search_depth: Literal["basic", "advanced"] = None, - topic: Literal["general", "news", "finance"] = None, - time_range: Literal["day", "week", "month", "year"] = None, - days: int = None, - max_results: int = None, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - include_answer: Union[bool, Literal["basic", "advanced"]] = None, - include_raw_content: Union[bool, Literal["markdown", "text"]] = None, - include_images: bool = None, - timeout: int = 60, - country: str = None, - auto_parameters: bool = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + async def search( + self, + query: str, + search_depth: Optional[Literal["basic", "advanced"]] = None, + topic: Optional[Literal["general", "news", "finance"]] = None, + time_range: Optional[Literal["day", "week", "month", "year"]] = None, + days: Optional[int] = None, + max_results: Optional[int] = None, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + include_answer: Optional[Union[bool, Literal["basic", "advanced"]]] = None, + include_raw_content: Optional[Union[bool, Literal["markdown", "text"]]] = None, + include_images: Optional[bool] = None, + timeout: int = 60, + country: Optional[str] = None, + auto_parameters: Optional[bool] = None, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> dict: """ Combined search method. Set search_depth to either "basic" or "advanced". 
""" timeout = min(timeout, 120) - response_dict = await self._search(query, - search_depth=search_depth, - topic=topic, - time_range=time_range, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=include_answer, - include_raw_content=include_raw_content, - include_images=include_images, - timeout=timeout, - country=country, - auto_parameters=auto_parameters, - include_favicon=include_favicon, - **kwargs, - ) + response_dict = await self._search( + query, + search_depth=search_depth, + topic=topic, + time_range=time_range, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=include_answer, + include_raw_content=include_raw_content, + include_images=include_images, + timeout=timeout, + country=country, + auto_parameters=auto_parameters, + include_favicon=include_favicon, + **kwargs, + ) tavily_results = response_dict.get("results", []) @@ -171,14 +171,14 @@ async def search(self, return response_dict async def _extract( - self, - urls: Union[List[str], str], - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs + self, + urls: Union[List[str], str], + include_images: Optional[bool] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, ) -> dict: """ Internal extract method to send the request to the API. @@ -214,40 +214,41 @@ async def _extract( except Exception: pass - if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: - raise response.raise_for_status() - - async def extract(self, - urls: Union[List[str], str], # Accept a list of URLs or a single URL - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + async def extract( + self, + urls: Union[List[str], str], # Accept a list of URLs or a single URL + include_images: Optional[bool] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> dict: """ Combined extract method. include_favicon: If True, include the favicon in the extraction results. 
""" timeout = min(timeout, 120) - response_dict = await self._extract(urls, - include_images, - extract_depth, - format, - timeout, - include_favicon=include_favicon, - **kwargs, - ) + response_dict = await self._extract( + urls, + include_images, + extract_depth, + format, + timeout, + include_favicon=include_favicon, + **kwargs, + ) tavily_results = response_dict.get("results", []) failed_results = response_dict.get("failed_results", []) @@ -256,26 +257,27 @@ async def extract(self, response_dict["failed_results"] = failed_results return response_dict - - async def _crawl(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs - ) -> dict: + + async def _crawl( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, + ) -> dict: """ Internal crawl method to send the request to the API. 
""" @@ -310,86 +312,90 @@ async def _crawl(self, except httpx.TimeoutException: raise TimeoutError(timeout) - if response.status_code == 200: - return response.json() + if response.status_code == 200: + return response.json() + else: + detail = "" + try: + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass + + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403, 432, 433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) else: - detail = "" - try: - detail = response.json().get("detail", {}).get("error", None) - except Exception: - pass - - if response.status_code == 429: - raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: - raise ForbiddenError(detail) - elif response.status_code == 401: - raise InvalidAPIKeyError(detail) - elif response.status_code == 400: - raise BadRequestError(detail) - else: - raise response.raise_for_status() - - async def crawl(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - categories: Sequence[AllowedCategory] = None, - extract_depth: Literal["basic", "advanced"] = None, - include_images: bool = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs - ) -> dict: + raise cast(Exception,response.raise_for_status()) + + async def crawl( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + include_images: Optional[bool] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, + ) -> dict: """ Combined crawl method. 
- + """ timeout = min(timeout, 120) - response_dict = await self._crawl(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - categories=categories, - extract_depth=extract_depth, - include_images=include_images, - format=format, - timeout=timeout, - include_favicon=include_favicon, - **kwargs) + response_dict = await self._crawl( + url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + categories=categories, + extract_depth=extract_depth, + include_images=include_images, + format=format, + timeout=timeout, + include_favicon=include_favicon, + **kwargs, + ) return response_dict - - async def _map(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - timeout: int = 60, - **kwargs - ) -> dict: + + async def _map( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + timeout: int = 60, + **kwargs, + ) -> dict: """ Internal map method to send the request to the API. 
""" @@ -421,78 +427,82 @@ async def _map(self, except httpx.TimeoutException: raise TimeoutError(timeout) - if response.status_code == 200: - return response.json() + if response.status_code == 200: + return response.json() + else: + detail = "" + try: + detail = response.json().get("detail", {}).get("error", None) + except Exception: + pass + + if response.status_code == 429: + raise UsageLimitExceededError(detail) + elif response.status_code in [403, 432, 433]: + raise ForbiddenError(detail) + elif response.status_code == 401: + raise InvalidAPIKeyError(detail) + elif response.status_code == 400: + raise BadRequestError(detail) else: - detail = "" - try: - detail = response.json().get("detail", {}).get("error", None) - except Exception: - pass - - if response.status_code == 429: - raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: - raise ForbiddenError(detail) - elif response.status_code == 401: - raise InvalidAPIKeyError(detail) - elif response.status_code == 400: - raise BadRequestError(detail) - else: - raise response.raise_for_status() - - async def map(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - timeout: int = 60, - **kwargs - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + async def map( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + timeout: int = 60, + **kwargs, + ) -> dict: """ Combined map method. 
""" timeout = min(timeout, 120) - response_dict = await self._map(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - categories=categories, - timeout=timeout, - **kwargs) + response_dict = await self._map( + url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + categories=categories, + timeout=timeout, + **kwargs, + ) return response_dict - async def get_search_context(self, - query: str, - search_depth: Literal["basic", "advanced"] = "basic", - topic: Literal["general", "news", "finance"] = "general", - days: int = 7, - max_results: int = 5, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - max_tokens: int = 4000, - timeout: int = 60, - country: str = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> str: + async def get_search_context( + self, + query: str, + search_depth: Literal["basic", "advanced"] = "basic", + topic: Literal["general", "news", "finance"] = "general", + days: int = 7, + max_results: int = 5, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + max_tokens: int = 4000, + timeout: int = 60, + country: Optional[str] = None, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> str: """ Get the search context for a query. Useful for getting only related content from retrieved websites without having to deal with context extraction and limitation yourself. @@ -502,77 +512,83 @@ async def get_search_context(self, Returns a string of JSON containing the search context up to context limit. 
""" timeout = min(timeout, 120) - response_dict = await self._search(query, - search_depth=search_depth, - topic=topic, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=False, - include_raw_content=False, - include_images=False, - timeout = timeout, - country=country, - include_favicon=include_favicon, - **kwargs, - ) + response_dict = await self._search( + query, + search_depth=search_depth, + topic=topic, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=False, + include_raw_content=False, + include_images=False, + timeout=timeout, + country=country, + include_favicon=include_favicon, + **kwargs, + ) sources = response_dict.get("results", []) context = [{"url": source["url"], "content": source["content"]} for source in sources] return json.dumps(get_max_items_from_list(context, max_tokens)) - async def qna_search(self, - query: str, - search_depth: Literal["basic", "advanced"] = "advanced", - topic: Literal["general", "news", "finance"] = "general", - days: int = 7, - max_results: int = 5, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - timeout: int = 60, - country: str = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> str: + async def qna_search( + self, + query: str, + search_depth: Literal["basic", "advanced"] = "advanced", + topic: Literal["general", "news", "finance"] = "general", + days: int = 7, + max_results: int = 5, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + timeout: int = 60, + country: Optional[str] = None, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> str: """ Q&A search method. Search depth is advanced by default to get the best answer. """ timeout = min(timeout, 120) - response_dict = await self._search(query, - search_depth=search_depth, - topic=topic, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_raw_content=False, - include_images=False, - include_answer=True, - timeout = timeout, - country=country, - include_favicon=include_favicon, - **kwargs, - ) + response_dict = await self._search( + query, + search_depth=search_depth, + topic=topic, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_raw_content=False, + include_images=False, + include_answer=True, + timeout=timeout, + country=country, + include_favicon=include_favicon, + **kwargs, + ) return response_dict.get("answer", "") - async def get_company_info(self, - query: str, - search_depth: Literal["basic", "advanced"] = "advanced", - max_results: int = 5, - timeout: int = 60, - country: str = None, - ) -> Sequence[dict]: - """ Company information search method. Search depth is advanced by default to get the best answer. """ + async def get_company_info( + self, + query: str, + search_depth: Literal["basic", "advanced"] = "advanced", + max_results: int = 5, + timeout: int = 60, + country: Optional[str] = None, + ) -> Sequence[dict]: + """Company information search method. 
Search depth is advanced by default to get the best answer.""" timeout = min(timeout, 120) async def _perform_search(topic: str): - return await self._search(query, - search_depth=search_depth, - topic=topic, - max_results=max_results, - include_answer=False, - timeout = timeout, - country=country) + return await self._search( + query, + search_depth=search_depth, + topic=topic, # type: ignore + max_results=max_results, + include_answer=False, + timeout=timeout, + country=country + ) all_results = [] for data in await asyncio.gather(*[_perform_search(topic) for topic in self._company_info_tags]): diff --git a/tavily/config.py b/tavily/config.py index f24593b..3d990da 100644 --- a/tavily/config.py +++ b/tavily/config.py @@ -5,8 +5,26 @@ # Create a type that represents all allowed categories AllowedCategory = Literal[ - "Documentation", "Blog", "Blogs", "Community", "About", "Contact", - "Privacy", "Terms", "Status", "Pricing", "Enterprise", "Careers", - "E-Commerce", "Authentication", "Developer", "Developers", "Solutions", - "Partners", "Downloads", "Media", "Events", "People" + "Documentation", + "Blog", + "Blogs", + "Community", + "About", + "Contact", + "Privacy", + "Terms", + "Status", + "Pricing", + "Enterprise", + "Careers", + "E-Commerce", + "Authentication", + "Developer", + "Developers", + "Solutions", + "Partners", + "Downloads", + "Media", + "Events", + "People", ] diff --git a/tavily/errors.py b/tavily/errors.py index 9342980..7da5404 100644 --- a/tavily/errors.py +++ b/tavily/errors.py @@ -1,5 +1,3 @@ -from typing import List, Dict, Any, Optional - class UsageLimitExceededError(Exception): def __init__(self, message: str): super().__init__(message) diff --git a/tavily/hybrid_rag/__init__.py b/tavily/hybrid_rag/__init__.py index f77f238..7149a26 100644 --- a/tavily/hybrid_rag/__init__.py +++ b/tavily/hybrid_rag/__init__.py @@ -1 +1,5 @@ -from .hybrid_rag import TavilyHybridClient \ No newline at end of file +from .hybrid_rag import TavilyHybridClient + +__all__ = [ + "TavilyHybridClient" +] \ No newline at end of file diff --git a/tavily/hybrid_rag/hybrid_rag.py b/tavily/hybrid_rag/hybrid_rag.py index c3ea97c..bc7dfaa 100644 --- a/tavily/hybrid_rag/hybrid_rag.py +++ b/tavily/hybrid_rag/hybrid_rag.py @@ -1,14 +1,13 @@ -import os -from typing import Union, Optional, Literal - from tavily import TavilyClient +from typing import Union, Optional, Literal, Callable try: import cohere co = cohere.Client() -except: +except ImportError: co = None + def _validate_index(client): """ Check that the index specified by the parameters exists and is a valid vector search index. @@ -19,68 +18,77 @@ def _validate_index(client): """ index_exists = False for index in client.collection.list_search_indexes(): - if index['name'] != client.index: + if index["name"] != client.index: continue - - if index['type'] != 'vectorSearch': - raise ValueError(f"Index '{client.index}' exists but is not of type " - "'vectorSearch'.") - + + if index["type"] != "vectorSearch": + raise ValueError( + f"Index '{client.index}' exists but is not of type 'vectorSearch'." 
+ ) + field_exists = False - for field in index['latestDefinition']['fields']: - if field['path'] != client.embeddings_field: + for field in index["latestDefinition"]["fields"]: + if field["path"] != client.embeddings_field: continue - - if field['type'] != 'vector': - raise ValueError(f"Field '{client.embeddings_field}' exists " - "but is not of type 'vector'.") - elif field['similarity'] != 'cosine': - raise ValueError(f"Field '{client.embeddings_field}' exists but has " - f"similarity '{field['similarity']}' instead of 'cosine'.") - + + if field["type"] != "vector": + raise ValueError( + f"Field '{client.embeddings_field}' exists but is not of type 'vector'." + ) + elif field["similarity"] != "cosine": + raise ValueError( + f"Field '{client.embeddings_field}' exists but has similarity '{field['similarity']}' instead of 'cosine'." + ) + field_exists = True break - + if not field_exists: - raise ValueError(f"Field '{client.embeddings_field}' does not exist in " - "index '{client.index}'.") - + raise ValueError( + f"Field '{client.embeddings_field}' does not exist in index '{{client.index}}'." + ) + index_exists = True - + if not index_exists: raise ValueError(f"Index '{client.index}' does not exist.") + def _cohere_embed(texts, type): - return co.embed( - model='embed-english-v3.0', - texts=texts, - input_type=type + return co.embed( # type: ignore + model="embed-english-v3.0", texts=texts, input_type=type ).embeddings + def _cohere_rerank(query, documents, top_n): - response = co.rerank(model='rerank-english-v3.0', query=query, - documents=[doc['content'] for doc in documents], top_n=top_n) - + response = co.rerank( # type: ignore + model="rerank-english-v3.0", + query=query, + documents=[doc["content"] for doc in documents], + top_n=top_n, + ) + return [ - documents[result.index] | {'score': result.relevance_score} + documents[result.index] | {"score": result.relevance_score} for result in response.results ] -class TavilyHybridClient(): + +class TavilyHybridClient: def __init__( - self, - api_key: Union[str, None], - db_provider: Literal['mongodb'], - collection, - index: str, - embeddings_field: str = 'embeddings', - content_field: str = 'content', - embedding_function: Optional[callable] = None, - ranking_function: Optional[callable] = None - ): - ''' + self, + api_key: Union[str, None], + db_provider: Literal["mongodb"], + collection, + index: str, + embeddings_field: str = "embeddings", + content_field: str = "content", + embedding_function: Optional[Callable] = None, + ranking_function: Optional[Callable] = None, + ): + """ A client for performing hybrid RAG using both the Tavily API and a local database collection. - + Parameters: api_key (str): The Tavily API key. If this is set to None, it will be loaded from the environment variable TAVILY_API_KEY. db_provider (str): The database provider. Currently only 'mongodb' is supported. @@ -88,31 +96,44 @@ def __init__( index (str): The name of the collection's vector search index. embeddings_field (str): The name of the field in the collection that contains the embeddings. content_field (str): The name of the field in the collection that contains the content. - embedding_function (callable): If provided, this function will be used to generate embeddings for the search query and documents. - ranking_function (callable): If provided, this function will be used to rerank the combined results. - ''' - + embedding_function (Callable): If provided, this function will be used to generate embeddings for the search query and documents. 
+ ranking_function (Callable): If provided, this function will be used to rerank the combined results. + """ + self.tavily = TavilyClient(api_key) - - if db_provider != 'mongodb': - raise ValueError("Only MongoDB is currently supported as a database provider.") - + + if db_provider != "mongodb": + raise ValueError( + "Only MongoDB is currently supported as a database provider." + ) + self.collection = collection self.index = index self.embeddings_field = embeddings_field self.content_field = content_field - - self.embedding_function = _cohere_embed if embedding_function is None else embedding_function - self.ranking_function = _cohere_rerank if ranking_function is None else ranking_function - + + self.embedding_function = ( + _cohere_embed if embedding_function is None else embedding_function + ) + self.ranking_function = ( + _cohere_rerank if ranking_function is None else ranking_function + ) + _validate_index(self) - def search(self, query, max_results=10, max_local=None, max_foreign=None, - save_foreign=False, **kwargs): - ''' + def search( + self, + query, + max_results=10, + max_local=None, + max_foreign=None, + save_foreign=False, + **kwargs, + ): + """ Return results for the given query from both the tavily API (foreign) and the specified mongo collection (local). - + Parameters: query (str): The query to search for. max_results (int): The maximum number of results to return. @@ -120,86 +141,97 @@ def search(self, query, max_results=10, max_local=None, max_foreign=None, max_foreign (int): The maximum number of foreign results to return. save_foreign (bool or function): Whether to save the foreign results in the collection. If a function is provided, it will be used to transform the foreign results before saving. - ''' + """ if max_local is None: max_local = max_results - + if max_foreign is None: max_foreign = max_results - query_embeddings = self.embedding_function([query], 'search_query')[0] + query_embeddings = self.embedding_function([query], "search_query")[0] # type: ignore # Search the local collection - local_results = list(self.collection.aggregate([ - { - "$vectorSearch": { - "index": self.index, - "path": self.embeddings_field, - "queryVector": query_embeddings, - "numCandidates": max_local + 3, - "limit": max_local - } - }, - { - "$project": { - "_id": 0, - "content": f"${self.content_field}", - "score": { - "$meta": "vectorSearchScore" + local_results = list( + self.collection.aggregate( + [ + { + "$vectorSearch": { + "index": self.index, + "path": self.embeddings_field, + "queryVector": query_embeddings, + "numCandidates": max_local + 3, + "limit": max_local, + } }, - "origin": "local" - } - } - ])) + { + "$project": { + "_id": 0, + "content": f"${self.content_field}", + "score": {"$meta": "vectorSearchScore"}, + "origin": "local", + } + }, + ] + ) + ) # Search using tavily if max_foreign > 0: - foreign_results = self.tavily.search(query, max_results=max_foreign, **kwargs)['results'] + foreign_results = self.tavily.search( + query, max_results=max_foreign, **kwargs + )["results"] else: foreign_results = [] # Combine the results projected_foreign_results = [ { - 'content': result['content'], - 'score': result['score'], - 'origin': 'foreign' + "content": result["content"], + "score": result["score"], + "origin": "foreign", } for result in foreign_results ] - + combined_results = local_results + projected_foreign_results - + if len(combined_results) == 0: return [] # Sort the combined results - combined_results = self.ranking_function(query, combined_results, max_results) 
+ combined_results = self.ranking_function( + query, combined_results, max_results + ) if len(combined_results) > max_results: combined_results = combined_results[:max_results] # Can't use 'not save_foreign' because save_foreign is not necessarily a boolean - if max_foreign > 0 and save_foreign != False: + if max_foreign > 0 and save_foreign != False: # noqa: E712 documents = [] - embeddings = self.embedding_function([result['content'] for result in foreign_results], 'search_document') + embeddings = self.embedding_function( + [result["content"] for result in foreign_results], + "search_document", + ) for i, result in enumerate(foreign_results): - result['embeddings'] = embeddings[i] - - if save_foreign == True: + result["embeddings"] = embeddings[i] # type: ignore + + if save_foreign == True: # noqa: E712 # No custom function provided, save as is - documents.append({ - self.content_field: result['content'], - self.embeddings_field: result['embeddings'] - }) + documents.append( + { + self.content_field: result["content"], + self.embeddings_field: result["embeddings"], + } + ) else: # save_foreign is a custom function result = save_foreign(result) if result: documents.append(result) - + # Add all in one call to make the operation atomic self.collection.insert_many(documents) - return combined_results \ No newline at end of file + return combined_results diff --git a/tavily/tavily.py b/tavily/tavily.py index ad55572..4802592 100644 --- a/tavily/tavily.py +++ b/tavily/tavily.py @@ -2,12 +2,13 @@ import json import warnings import os -from typing import Literal, Sequence, Optional, List, Union +from typing import Literal, Sequence, Optional, List, Union, cast from concurrent.futures import ThreadPoolExecutor, as_completed from .utils import get_max_items_from_list from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError from .config import AllowedCategory + class TavilyClient: """ Tavily API client class. 
@@ -30,34 +31,30 @@ def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, st self.base_url = api_base_url or "https://api.tavily.com" self.api_key = api_key self.proxies = resolved_proxies - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - "X-Client-Source": "tavily-python" - } - - def _search(self, - query: str, - search_depth: Literal["basic", "advanced"] = None, - topic: Literal["general", "news", "finance"] = None, - time_range: Literal["day", "week", "month", "year"] = None, - days: int = None, - max_results: int = None, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - include_answer: Union[bool, Literal["basic", "advanced"]] = None, - include_raw_content: Union[bool, Literal["markdown", "text"]] = None, - include_images: bool = None, - timeout: int = 60, - country: str = None, - auto_parameters: bool = None, - include_favicon: bool = None, - **kwargs - ) -> dict: + self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", "X-Client-Source": "tavily-python"} + + def _search( + self, + query: str, + search_depth: Optional[Literal["basic", "advanced"]] = None, + topic: Optional[Literal["general", "news", "finance"]] = None, + time_range: Optional[Literal["day", "week", "month", "year"]] = None, + days: Optional[int] = None, + max_results: Optional[int] = None, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + include_answer: Optional[Union[bool, Literal["basic", "advanced"]]] = None, + include_raw_content: Optional[Union[bool, Literal["markdown", "text"]]] = None, + include_images: Optional[bool] = None, + timeout: int = 60, + country: Optional[str] = None, + auto_parameters: Optional[bool] = None, + include_favicon: Optional[bool] = None, + **kwargs, + ) -> dict: """ Internal search method to send the request to the API. 
""" - data = { "query": query, "search_depth": search_depth, @@ -98,74 +95,78 @@ def _search(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) - else: - raise response.raise_for_status() - - - def search(self, - query: str, - search_depth: Literal["basic", "advanced"] = None, - topic: Literal["general", "news", "finance" ] = None, - time_range: Literal["day", "week", "month", "year"] = None, - days: int = None, - max_results: int = None, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - include_answer: Union[bool, Literal["basic", "advanced"]] = None, - include_raw_content: Union[bool, Literal["markdown", "text"]] = None, - include_images: bool = None, - timeout: int = 60, - country: str = None, - auto_parameters: bool = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + def search( + self, + query: str, + search_depth: Optional[Literal["basic", "advanced"]] = None, + topic: Optional[Literal["general", "news", "finance"]] = None, + time_range: Optional[Literal["day", "week", "month", "year"]] = None, + days: Optional[int] = None, + max_results: Optional[int] = None, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + include_answer: Optional[Union[bool, Literal["basic", "advanced"]]] = None, + include_raw_content: Optional[Union[bool, Literal["markdown", "text"]]] = None, + include_images: Optional[bool] = None, + timeout: int = 60, + country: Optional[str] = None, + auto_parameters: Optional[bool] = None, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> dict: """ Combined search method. 
""" - timeout = min(timeout, 120) - response_dict = self._search(query, - search_depth=search_depth, - topic=topic, - time_range=time_range, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=include_answer, - include_raw_content=include_raw_content, - include_images=include_images, - timeout=timeout, - country=country, - auto_parameters=auto_parameters, - include_favicon=include_favicon, - **kwargs, - ) - - tavily_results = response_dict.get("results", []) - - response_dict["results"] = tavily_results - - return response_dict - - def _extract(self, - urls: Union[List[str], str], - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs - ) -> dict: + try: + timeout = min(timeout, 120) + response_dict = self._search( + query, + search_depth=search_depth, + topic=topic, + time_range=time_range, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=include_answer, + include_raw_content=include_raw_content, + include_images=include_images, + timeout=timeout, + country=country, + auto_parameters=auto_parameters, + include_favicon=include_favicon, + **kwargs, + ) + + tavily_results = response_dict.get("results", []) + + response_dict["results"] = tavily_results + + return response_dict + except Exception as exc: + raise exc + + def _extract( + self, + urls: Union[List[str], str], + include_images: Optional[bool] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, + ) -> dict: """ - Internal extract method to send the request to the API. + Internal extract method to send the request to the API. """ data = { "urls": urls, @@ -198,35 +199,30 @@ def _extract(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: - raise response.raise_for_status() - - def extract(self, - urls: Union[List[str], str], # Accept a list of URLs or a single URL - include_images: bool = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + def extract( + self, + urls: Union[List[str], str], # Accept a list of URLs or a single URL + include_images: Optional[bool] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> dict: """ Combined extract method. 
""" timeout = min(timeout, 120) - response_dict = self._extract(urls, - include_images, - extract_depth, - format, - timeout, - include_favicon=include_favicon, - **kwargs) + response_dict = self._extract(urls, include_images, extract_depth, format, timeout, include_favicon=include_favicon, **kwargs) tavily_results = response_dict.get("results", []) failed_results = response_dict.get("failed_results", []) @@ -236,25 +232,26 @@ def extract(self, return response_dict - def _crawl(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs - ) -> dict: + def _crawl( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, + ) -> dict: """ Internal crawl method to send the request to the API. include_favicon: If True, include the favicon in the crawl results. 
@@ -279,14 +276,13 @@ def _crawl(self, if kwargs: data.update(kwargs) - + data = {k: v for k, v in data.items() if v is not None} - + timeout = min(timeout, 120) try: - response = requests.post( - self.base_url + "/crawl", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) + response = requests.post(self.base_url + "/crawl", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -301,75 +297,79 @@ def _crawl(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: - raise response.raise_for_status() - - def crawl(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - extract_depth: Literal["basic", "advanced"] = None, - format: Literal["markdown", "text"] = None, - timeout: int = 60, - include_favicon: bool = None, - **kwargs - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + def crawl( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + extract_depth: Optional[Literal["basic", "advanced"]] = None, + format: Optional[Literal["markdown", "text"]] = None, + timeout: int = 60, + include_favicon: Optional[bool] = None, + **kwargs, + ) -> dict: """ Combined crawl method. include_favicon: If True, include the favicon in the crawl results. 
""" timeout = min(timeout, 120) - response_dict = self._crawl(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - categories=categories, - extract_depth=extract_depth, - format=format, - timeout=timeout, - include_favicon=include_favicon, - **kwargs) + response_dict = self._crawl( + url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + categories=categories, + extract_depth=extract_depth, + format=format, + timeout=timeout, + include_favicon=include_favicon, + **kwargs, + ) return response_dict - - def _map(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - timeout: int = 60, - **kwargs - ) -> dict: + + def _map( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + timeout: int = 60, + **kwargs, + ) -> dict: """ Internal map method to send the request to the API. 
""" @@ -390,14 +390,13 @@ def _map(self, if kwargs: data.update(kwargs) - + data = {k: v for k, v in data.items() if v is not None} timeout = min(timeout, 120) try: - response = requests.post( - self.base_url + "/map", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) + response = requests.post(self.base_url + "/map", data=json.dumps(data), headers=self.headers, timeout=timeout, proxies=self.proxies) except requests.exceptions.Timeout: raise TimeoutError(timeout) @@ -412,67 +411,71 @@ def _map(self, if response.status_code == 429: raise UsageLimitExceededError(detail) - elif response.status_code in [403,432,433]: + elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: raise InvalidAPIKeyError(detail) elif response.status_code == 400: raise BadRequestError(detail) else: - raise response.raise_for_status() - - def map(self, - url: str, - max_depth: int = None, - max_breadth: int = None, - limit: int = None, - instructions: str = None, - select_paths: Sequence[str] = None, - select_domains: Sequence[str] = None, - exclude_paths: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - allow_external: bool = None, - include_images: bool = None, - categories: Sequence[AllowedCategory] = None, - timeout: int = 60, - **kwargs - ) -> dict: + raise cast(Exception, response.raise_for_status()) + + def map( + self, + url: str, + max_depth: Optional[int] = None, + max_breadth: Optional[int] = None, + limit: Optional[int] = None, + instructions: Optional[str] = None, + select_paths: Optional[Sequence[str]] = None, + select_domains: Optional[Sequence[str]] = None, + exclude_paths: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + allow_external: Optional[bool] = None, + include_images: Optional[bool] = None, + categories: Optional[Sequence[AllowedCategory]] = None, + timeout: int = 60, + **kwargs, + ) -> dict: """ Combined map method. 
- + """ timeout = min(timeout, 120) - response_dict = self._map(url, - max_depth=max_depth, - max_breadth=max_breadth, - limit=limit, - instructions=instructions, - select_paths=select_paths, - select_domains=select_domains, - exclude_paths=exclude_paths, - exclude_domains=exclude_domains, - allow_external=allow_external, - include_images=include_images, - categories=categories, - timeout=timeout, - **kwargs) + response_dict = self._map( + url, + max_depth=max_depth, + max_breadth=max_breadth, + limit=limit, + instructions=instructions, + select_paths=select_paths, + select_domains=select_domains, + exclude_paths=exclude_paths, + exclude_domains=exclude_domains, + allow_external=allow_external, + include_images=include_images, + categories=categories, + timeout=timeout, + **kwargs, + ) return response_dict - def get_search_context(self, - query: str, - search_depth: Literal["basic", "advanced"] = "basic", - topic: Literal["general", "news", "finance"] = "general", - days: int = 7, - max_results: int = 5, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - max_tokens: int = 4000, - timeout: int = 60, - country: str = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> str: + def get_search_context( + self, + query: str, + search_depth: Literal["basic", "advanced"] = "basic", + topic: Literal["general", "news", "finance"] = "general", + days: int = 7, + max_results: int = 5, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + max_tokens: int = 4000, + timeout: int = 60, + country: Optional[str] = None, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> str: """ Get the search context for a query. Useful for getting only related content from retrieved websites without having to deal with context extraction and limitation yourself. @@ -482,95 +485,90 @@ def get_search_context(self, Returns a string of JSON containing the search context up to context limit. 
""" timeout = min(timeout, 120) - response_dict = self._search(query, - search_depth=search_depth, - topic=topic, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=False, - include_raw_content=False, - include_images=False, - timeout=timeout, - country=country, - include_favicon=include_favicon, - **kwargs, - ) + response_dict = self._search( + query, + search_depth=search_depth, + topic=topic, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=False, + include_raw_content=False, + include_images=False, + timeout=timeout, + country=country, + include_favicon=include_favicon, + **kwargs, + ) sources = response_dict.get("results", []) - context = [{"url": source["url"], "content": source["content"]} - for source in sources] + context = [{"url": source["url"], "content": source["content"]} for source in sources] return json.dumps(get_max_items_from_list(context, max_tokens)) - def qna_search(self, - query: str, - search_depth: Literal["basic", "advanced"] = "advanced", - topic: Literal["general", "news", "finance"] = "general", - days: int = 7, - max_results: int = 5, - include_domains: Sequence[str] = None, - exclude_domains: Sequence[str] = None, - timeout: int = 60, - country: str = None, - include_favicon: bool = None, - **kwargs, # Accept custom arguments - ) -> str: + def qna_search( + self, + query: str, + search_depth: Literal["basic", "advanced"] = "advanced", + topic: Literal["general", "news", "finance"] = "general", + days: int = 7, + max_results: int = 5, + include_domains: Optional[Sequence[str]] = None, + exclude_domains: Optional[Sequence[str]] = None, + timeout: int = 60, + country: Optional[str] = None, + include_favicon: Optional[bool] = None, + **kwargs, # Accept custom arguments + ) -> str: """ Q&A search method. Search depth is advanced by default to get the best answer. """ timeout = min(timeout, 120) - response_dict = self._search(query, - search_depth=search_depth, - topic=topic, - days=days, - max_results=max_results, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_raw_content=False, - include_images=False, - include_answer=True, - timeout=timeout, - country=country, - include_favicon=include_favicon, - **kwargs, - ) + response_dict = self._search( + query, + search_depth=search_depth, + topic=topic, + days=days, + max_results=max_results, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_raw_content=False, + include_images=False, + include_answer=True, + timeout=timeout, + country=country, + include_favicon=include_favicon, + **kwargs, + ) return response_dict.get("answer", "") - def get_company_info(self, - query: str, - search_depth: Literal["basic", - "advanced"] = "advanced", - max_results: int = 5, - timeout: int = 60, - country: str = None, - ) -> Sequence[dict]: - """ Company information search method. Search depth is advanced by default to get the best answer. """ + def get_company_info( + self, + query: str, + search_depth: Literal["basic", "advanced"] = "advanced", + max_results: int = 5, + timeout: int = 60, + country: Optional[str] = None, + ) -> Sequence[dict]: + """Company information search method. 
Search depth is advanced by default to get the best answer.""" timeout = min(timeout, 120) + def _perform_search(topic): - return self._search(query, - search_depth=search_depth, - topic=topic, - max_results=max_results, - include_answer=False, - timeout=timeout, - country=country) + return self._search(query, search_depth=search_depth, topic=topic, max_results=max_results, include_answer=False, timeout=timeout, country=country) with ThreadPoolExecutor() as executor: # Initiate the search for each topic in parallel - future_to_topic = {executor.submit(_perform_search, topic): topic for topic in - ["news", "general", "finance"]} + future_to_topic = {executor.submit(_perform_search, topic): topic for topic in ["news", "general", "finance"]} all_results = [] # Process the results as they become available for future in as_completed(future_to_topic): data = future.result() - if 'results' in data: - all_results.extend(data['results']) + if "results" in data: + all_results.extend(data["results"]) # Sort all the results by score in descending order and take the top 'max_results' items - sorted_results = sorted(all_results, key=lambda x: x['score'], reverse=True)[ - :max_results] + sorted_results = sorted(all_results, key=lambda x: x["score"], reverse=True)[:max_results] return sorted_results @@ -583,6 +581,5 @@ class Client(TavilyClient): """ def __init__(self, kwargs): - warnings.warn("Client is deprecated, please use TavilyClient instead", - DeprecationWarning, stacklevel=2) + warnings.warn("Client is deprecated, please use TavilyClient instead", DeprecationWarning, stacklevel=2) super().__init__(kwargs) diff --git a/tavily/utils.py b/tavily/utils.py index 6d4e90d..7664182 100644 --- a/tavily/utils.py +++ b/tavily/utils.py @@ -6,7 +6,7 @@ def get_total_tokens_from_string(string: str, encoding_name: str = DEFAULT_MODEL_ENCODING) -> int: """ - Get total amount of tokens from string using the specified encoding (based on openai compute) + Get total amount of tokens from string using the specified encoding (based on openai compute) """ encoding = tiktoken.encoding_for_model(encoding_name) tokens = encoding.encode(string) @@ -15,7 +15,7 @@ def get_total_tokens_from_string(string: str, encoding_name: str = DEFAULT_MODEL def get_max_tokens_from_string(string: str, max_tokens: int, encoding_name: str = DEFAULT_MODEL_ENCODING) -> str: """ - Extract max tokens from string using the specified encoding (based on openai compute) + Extract max tokens from string using the specified encoding (based on openai compute) """ encoding = tiktoken.encoding_for_model(encoding_name) tokens = encoding.encode(string) @@ -23,9 +23,9 @@ def get_max_tokens_from_string(string: str, max_tokens: int, encoding_name: str return b"".join(token_bytes).decode() -def get_max_items_from_list(data: Sequence[dict], max_tokens: int = DEFAULT_MAX_TOKENS) -> List[Dict[str,str]]: +def get_max_items_from_list(data: Sequence[dict], max_tokens: int = DEFAULT_MAX_TOKENS) -> List[Dict[str, str]]: """ - Get max items from list of items based on defined max tokens (based on openai compute) + Get max items from list of items based on defined max tokens (based on openai compute) """ result = [] current_tokens = 0