diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index c8d06a9cd5..5f5ae37243 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -66,6 +66,8 @@ jobs: NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}" NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}" FIRECRAWL_API_KEY: "${{ secrets.FIRECRAWL_API_KEY }}" + ASKNEWS_CLIENT_ID: "${{ secrets.ASKNEWS_CLIENT_ID }}" + ASKNEWS_CLIENT_SECRET: "${{ secrets.ASKNEWS_CLIENT_SECRET }}" run: | source venv/bin/activate pytest --fast-test-mode ./test diff --git a/.github/workflows/pytest_package.yml b/.github/workflows/pytest_package.yml index 7c6b6a452c..6d1889f922 100644 --- a/.github/workflows/pytest_package.yml +++ b/.github/workflows/pytest_package.yml @@ -45,6 +45,8 @@ jobs: NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}" NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}" FIRECRAWL_API_KEY: "${{ secrets.FIRECRAWL_API_KEY }}" + ASKNEWS_CLIENT_ID: "${{ secrets.ASKNEWS_CLIENT_ID }}" + ASKNEWS_CLIENT_SECRET: "${{ secrets.ASKNEWS_CLIENT_SECRET }}" run: poetry run pytest --fast-test-mode test/ pytest_package_llm_test: @@ -79,6 +81,8 @@ jobs: NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}" NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}" FIRECRAWL_API_KEY: "${{ secrets.FIRECRAWL_API_KEY }}" + ASKNEWS_CLIENT_ID: "${{ secrets.ASKNEWS_CLIENT_ID }}" + ASKNEWS_CLIENT_SECRET: "${{ secrets.ASKNEWS_CLIENT_SECRET }}" run: poetry run pytest --llm-test-only test/ pytest_package_very_slow_test: @@ -113,4 +117,6 @@ jobs: NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}" NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}" FIRECRAWL_API_KEY: "${{ secrets.FIRECRAWL_API_KEY }}" + ASKNEWS_CLIENT_ID: "${{ secrets.ASKNEWS_CLIENT_ID }}" + ASKNEWS_CLIENT_SECRET: "${{ secrets.ASKNEWS_CLIENT_SECRET }}" run: poetry run pytest --very-slow-test-only test/ diff --git a/camel/toolkits/__init__.py b/camel/toolkits/__init__.py index f703d5a7c9..d1658959c1 100644 --- a/camel/toolkits/__init__.py +++ b/camel/toolkits/__init__.py @@ -25,15 +25,16 @@ from .search_toolkit import SearchToolkit, SEARCH_FUNCS from .weather_toolkit import WeatherToolkit, WEATHER_FUNCS from .dalle_toolkit import DalleToolkit, DALLE_FUNCS +from .ask_news_toolkit import AskNewsToolkit, AsyncAskNewsToolkit +from .linkedin_toolkit import LinkedInToolkit +from .reddit_toolkit import RedditToolkit from .base import BaseToolkit from .google_maps_toolkit import GoogleMapsToolkit from .code_execution import CodeExecutionToolkit from .github_toolkit import GithubToolkit from .google_scholar_toolkit import GoogleScholarToolkit from .arxiv_toolkit import ArxivToolkit -from .linkedin_toolkit import LinkedInToolkit -from .reddit_toolkit import RedditToolkit from .slack_toolkit import SlackToolkit from .twitter_toolkit import TwitterToolkit from .open_api_toolkit import OpenAPIToolkit @@ -59,6 +60,8 @@ 'LinkedInToolkit', 'RedditToolkit', 'CodeExecutionToolkit', + 'AskNewsToolkit', + 'AsyncAskNewsToolkit', 'GoogleScholarToolkit', 'ArxivToolkit', 'MATH_FUNCS', diff --git a/camel/toolkits/ask_news_toolkit.py b/camel/toolkits/ask_news_toolkit.py new file mode 100644 index 0000000000..5e319a32bf --- /dev/null +++ b/camel/toolkits/ask_news_toolkit.py @@ -0,0 +1,653 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+import os
+from datetime import datetime
+from typing import List, Literal, Optional, Tuple, Union
+
+from camel.toolkits import FunctionTool
+from camel.toolkits.base import BaseToolkit
+
+
+def _process_response(
+    response, return_type: str
+) -> Union[str, dict, Tuple[str, dict]]:
+    r"""Process the response based on the specified return type.
+
+    This helper function processes the API response and returns the content
+    in the specified format, which could be a string, a dictionary, or
+    both.
+
+    Args:
+        response: The response object returned by the API call.
+        return_type (str): Specifies the format of the return value. It
+            can be "string" to return the response as a string, "dicts" to
+            return it as a dictionary, or "both" to return both formats as
+            a tuple.
+
+    Returns:
+        Union[str, dict, Tuple[str, dict]]: The processed response,
+            formatted according to the return_type argument. If "string",
+            returns the response as a string. If "dicts", returns the
+            response as a dictionary. If "both", returns a tuple
+            containing both formats.
+
+    Raises:
+        ValueError: If the return_type provided is invalid.
+    """
+    if return_type == "string":
+        return response.as_string
+    elif return_type == "dicts":
+        return response.as_dicts
+    elif return_type == "both":
+        return (response.as_string, response.as_dicts)
+    else:
+        raise ValueError(f"Invalid return_type: {return_type}")
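+
+
+# Behavior sketch for the helper above, where `resp` stands in for a
+# hypothetical SDK response object exposing `as_string` and `as_dicts`:
+#
+#     _process_response(resp, "string")  # -> resp.as_string
+#     _process_response(resp, "dicts")   # -> resp.as_dicts
+#     _process_response(resp, "both")    # -> (resp.as_string, resp.as_dicts)
+#     _process_response(resp, "json")    # -> raises ValueError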
+ """ + try: + response = self.asknews_client.news.search_news( + query=query, + n_articles=n_articles, + return_type=return_type, + method=method, + ) + + return _process_response(response, return_type) + + except Exception as e: + return f"Got error: {e}" + + def get_stories( + self, + query: str, + categories: List[ + Literal[ + 'Politics', + 'Economy', + 'Finance', + 'Science', + 'Technology', + 'Sports', + 'Climate', + 'Environment', + 'Culture', + 'Entertainment', + 'Business', + 'Health', + 'International', + ] + ], + reddit: int = 3, + expand_updates: bool = True, + max_updates: int = 2, + max_articles: int = 10, + ) -> Union[dict, str]: + r"""Fetch stories based on the provided parameters. + + Args: + query (str): The search query for fetching relevant stories. + categories (list): The categories to filter stories by. + reddit (int): Number of Reddit threads to include. + (default: :obj:`3`) + expand_updates (bool): Whether to include detailed updates. + (default: :obj:`True`) + max_updates (int): Maximum number of recent updates per story. + (default: :obj:`2`) + max_articles (int): Maximum number of articles associated with + each update. (default: :obj:`10`) + + Returns: + Unio[dict, str]: A dictionary containing the stories and their + associated data, or error message if the process fails. + """ + try: + response = self.asknews_client.stories.search_stories( + query=query, + categories=categories, + reddit=reddit, + expand_updates=expand_updates, + max_updates=max_updates, + max_articles=max_articles, + ) + + # Collect only the headline and story content from the updates + stories_data = { + "stories": [ + { + "headline": story.updates[0].headline, + "updates": [ + { + "headline": update.headline, + "story": update.story, + } + for update in story.updates[:max_updates] + ], + } + for story in response.stories + ] + } + + return stories_data + + except Exception as e: + return f"Got error: {e}" + + def get_web_search( + self, + queries: List[str], + return_type: Literal["string", "dicts", "both"] = "string", + ) -> Union[str, dict, Tuple[str, dict]]: + r"""Perform a live web search based on the given queries. + + Args: + queries (List[str]): A list of search queries. + return_type (Literal["string", "dicts", "both"]): The format of the + return value. (default: :obj:`"string"`) + + Returns: + Union[str, dict, Tuple[str, dict]]: A string, + dictionary, or both containing the search results, or + error message if the process fails. + """ + try: + response = self.asknews_client.chat.live_web_search( + queries=queries + ) + + return _process_response(response, return_type) + + except Exception as e: + return f"Got error: {e}" + + def search_reddit( + self, + keywords: List[str], + n_threads: int = 5, + return_type: Literal["string", "dicts", "both"] = "string", + method: Literal["nl", "kw"] = "kw", + ) -> Union[str, dict, Tuple[str, dict]]: + r"""Search Reddit based on the provided keywords. + + Args: + keywords (List[str]): The keywords to search for on Reddit. + n_threads (int): Number of Reddit threads to summarize and return. + (default: :obj:`5`) + return_type (Literal["string", "dicts", "both"]): The format of the + return value. (default: :obj:`"string"`) + method (Literal["nl", "kw"]): The search method, either "nl" for + natural language or "kw" for keyword search. + (default::obj:`"kw"`) + + Returns: + Union[str, dict, Tuple[str, dict]]: The Reddit search + results as a string, dictionary, or both, or error message if + the process fails. 
+ """ + try: + response = self.asknews_client.news.search_reddit( + keywords=keywords, n_threads=n_threads, method=method + ) + + return _process_response(response, return_type) + + except Exception as e: + return f"Got error: {e}" + + def query_finance( + self, + asset: Literal[ + 'bitcoin', + 'ethereum', + 'cardano', + 'uniswap', + 'ripple', + 'solana', + 'polkadot', + 'polygon', + 'chainlink', + 'tether', + 'dogecoin', + 'monero', + 'tron', + 'binance', + 'aave', + 'tesla', + 'microsoft', + 'amazon', + ], + metric: Literal[ + 'news_positive', + 'news_negative', + 'news_total', + 'news_positive_weighted', + 'news_negative_weighted', + 'news_total_weighted', + ] = "news_positive", + return_type: Literal["list", "string"] = "string", + date_from: Optional[datetime] = None, + date_to: Optional[datetime] = None, + ) -> Union[list, str]: + r"""Fetch asset sentiment data for a given asset, metric, and date + range. + + Args: + asset (Literal): The asset for which to fetch sentiment data. + metric (Literal): The sentiment metric to analyze. + return_type (Literal["list", "string"]): The format of the return + value. (default: :obj:`"string"`) + date_from (datetime, optional): The start date and time for the + data in ISO 8601 format. + date_to (datetime, optional): The end date and time for the data + in ISO 8601 format. + + Returns: + Union[list, str]: A list of dictionaries containing the datetime + and value or a string describing all datetime and value pairs + for providing quantified time-series data for news sentiment + on topics of interest, or an error message if the process + fails. + """ + try: + response = self.asknews_client.analytics.get_asset_sentiment( + asset=asset, + metric=metric, + date_from=date_from, + date_to=date_to, + ) + + time_series_data = response.data.timeseries + + if return_type == "list": + return time_series_data + elif return_type == "string": + header = ( + f"This is the sentiment analysis for '{asset}' based " + + f"on the '{metric}' metric from {date_from} to {date_to}" + + ". The values reflect the aggregated sentiment from news" + + " sources for each given time period.\n" + ) + descriptive_text = "\n".join( + [ + f"On {entry.datetime}, the sentiment value was " + f"{entry.value}." + for entry in time_series_data + ] + ) + return header + descriptive_text + + except Exception as e: + return f"Got error: {e}" + + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of FunctionTool objects representing the functions + in the toolkit. + + Returns: + List[FunctionTool]: A list of FunctionTool objects representing + the functions in the toolkit. + """ + return [ + FunctionTool(self.get_news), + FunctionTool(self.get_stories), + FunctionTool(self.get_web_search), + FunctionTool(self.search_reddit), + FunctionTool(self.query_finance), + ] + + +class AsyncAskNewsToolkit(BaseToolkit): + r"""A class representing a toolkit for interacting with the AskNews API + asynchronously. + + This class provides methods for fetching news, stories, and other + content based on user queries using the AskNews API. + """ + + def __init__(self): + r"""Initialize the AsyncAskNewsToolkit with API clients.The API keys + and credentials are retrieved from environment variables. 
+ """ + from asknews_sdk import AsyncAskNewsSDK # type: ignore[import] + + client_id = os.environ.get("ASKNEWS_CLIENT_ID") + client_secret = os.environ.get("ASKNEWS_CLIENT_SECRET") + + if client_id and client_secret: + self.asknews_client = AsyncAskNewsSDK(client_id, client_secret) + else: + self.asknews_client = None + + async def get_news( + self, + query: str, + n_articles: int = 10, + return_type: Literal["string", "dicts", "both"] = "string", + method: Literal["nl", "kw"] = "kw", + ) -> Union[str, dict, Tuple[str, dict]]: + r"""Fetch news or stories based on a user query. + + Args: + query (str): The search query for fetching relevant news or + stories. + n_articles (int): Number of articles to include in the response. + (default: :obj:10) + return_type (Literal["string", "dicts", "both"]): The format of the + return value. (default: :obj:"string") + method (Literal["nl", "kw"]): The search method, either "nl" for + natural language or "kw" for keyword search. (default: + :obj:"kw") + + Returns: + Union[str, dict, Tuple[str, dict]]: A string, + dictionary, or both containing the news or story content, or + error message if the process fails. + """ + try: + response = await self.asknews_client.news.search_news( + query=query, + n_articles=n_articles, + return_type=return_type, + method=method, + ) + + return _process_response(response, return_type) + + except Exception as e: + return f"Got error: {e}" + + async def get_stories( + self, + query: str, + categories: List[ + Literal[ + 'Politics', + 'Economy', + 'Finance', + 'Science', + 'Technology', + 'Sports', + 'Climate', + 'Environment', + 'Culture', + 'Entertainment', + 'Business', + 'Health', + 'International', + ] + ], + reddit: int = 3, + expand_updates: bool = True, + max_updates: int = 2, + max_articles: int = 10, + ) -> Union[dict, str]: + r"""Fetch stories based on the provided parameters. + + Args: + query (str): The search query for fetching relevant stories. + categories (list): The categories to filter stories by. + reddit (int): Number of Reddit threads to include. + (default: :obj:`3`) + expand_updates (bool): Whether to include detailed updates. + (default: :obj:`True`) + max_updates (int): Maximum number of recent updates per story. + (default: :obj:`2`) + max_articles (int): Maximum number of articles associated with + each update. (default: :obj:`10`) + + Returns: + Unio[dict, str]: A dictionary containing the stories and their + associated data, or error message if the process fails. + """ + try: + response = await self.asknews_client.stories.search_stories( + query=query, + categories=categories, + reddit=reddit, + expand_updates=expand_updates, + max_updates=max_updates, + max_articles=max_articles, + ) + + # Collect only the headline and story content from the updates + stories_data = { + "stories": [ + { + "headline": story.updates[0].headline, + "updates": [ + { + "headline": update.headline, + "story": update.story, + } + for update in story.updates[:max_updates] + ], + } + for story in response.stories + ] + } + + return stories_data + + except Exception as e: + return f"Got error: {e}" + + async def get_web_search( + self, + queries: List[str], + return_type: Literal["string", "dicts", "both"] = "string", + ) -> Union[str, dict, Tuple[str, dict]]: + r"""Perform a live web search based on the given queries. + + Args: + queries (List[str]): A list of search queries. + return_type (Literal["string", "dicts", "both"]): The format of the + return value. 
+    async def get_web_search(
+        self,
+        queries: List[str],
+        return_type: Literal["string", "dicts", "both"] = "string",
+    ) -> Union[str, dict, Tuple[str, dict]]:
+        r"""Perform a live web search based on the given queries.
+
+        Args:
+            queries (List[str]): A list of search queries.
+            return_type (Literal["string", "dicts", "both"]): The format of
+                the return value. (default: :obj:`"string"`)
+
+        Returns:
+            Union[str, dict, Tuple[str, dict]]: A string, dictionary, or both
+                containing the search results, or error message if the
+                process fails.
+        """
+        try:
+            response = await self.asknews_client.chat.live_web_search(
+                queries=queries
+            )
+
+            return _process_response(response, return_type)
+
+        except Exception as e:
+            return f"Got error: {e}"
+
+    async def search_reddit(
+        self,
+        keywords: List[str],
+        n_threads: int = 5,
+        return_type: Literal["string", "dicts", "both"] = "string",
+        method: Literal["nl", "kw"] = "kw",
+    ) -> Union[str, dict, Tuple[str, dict]]:
+        r"""Search Reddit based on the provided keywords.
+
+        Args:
+            keywords (list): The keywords to search for on Reddit.
+            n_threads (int): Number of Reddit threads to summarize and return.
+                (default: :obj:`5`)
+            return_type (Literal["string", "dicts", "both"]): The format of
+                the return value. (default: :obj:`"string"`)
+            method (Literal["nl", "kw"]): The search method, either "nl" for
+                natural language or "kw" for keyword search.
+                (default: :obj:`"kw"`)
+
+        Returns:
+            Union[str, dict, Tuple[str, dict]]: The Reddit search results as
+                a string, dictionary, or both, or error message if the
+                process fails.
+        """
+        try:
+            response = await self.asknews_client.news.search_reddit(
+                keywords=keywords, n_threads=n_threads, method=method
+            )
+
+            return _process_response(response, return_type)
+
+        except Exception as e:
+            return f"Got error: {e}"
+
+    async def query_finance(
+        self,
+        asset: Literal[
+            'bitcoin',
+            'ethereum',
+            'cardano',
+            'uniswap',
+            'ripple',
+            'solana',
+            'polkadot',
+            'polygon',
+            'chainlink',
+            'tether',
+            'dogecoin',
+            'monero',
+            'tron',
+            'binance',
+            'aave',
+            'tesla',
+            'microsoft',
+            'amazon',
+        ],
+        metric: Literal[
+            'news_positive',
+            'news_negative',
+            'news_total',
+            'news_positive_weighted',
+            'news_negative_weighted',
+            'news_total_weighted',
+        ] = "news_positive",
+        return_type: Literal["list", "string"] = "string",
+        date_from: Optional[datetime] = None,
+        date_to: Optional[datetime] = None,
+    ) -> Union[list, str]:
+        r"""Fetch asset sentiment data for a given asset, metric, and date
+        range.
+
+        Args:
+            asset (Literal): The asset for which to fetch sentiment data.
+            metric (Literal): The sentiment metric to analyze.
+            return_type (Literal["list", "string"]): The format of the return
+                value. (default: :obj:`"string"`)
+            date_from (datetime, optional): The start date and time for the
+                data in ISO 8601 format.
+            date_to (datetime, optional): The end date and time for the data
+                in ISO 8601 format.
+
+        Returns:
+            Union[list, str]: A list of dictionaries containing the datetime
+                and value or a string describing all datetime and value pairs
+                for providing quantified time-series data for news sentiment
+                on topics of interest, or an error message if the process
+                fails.
+        """
+        try:
+            response = await self.asknews_client.analytics.get_asset_sentiment(
+                asset=asset,
+                metric=metric,
+                date_from=date_from,
+                date_to=date_to,
+            )
+
+            time_series_data = response.data.timeseries
+
+            if return_type == "list":
+                return time_series_data
+            elif return_type == "string":
+                header = (
+                    f"This is the sentiment analysis for '{asset}' based "
+                    + f"on the '{metric}' metric from {date_from} to {date_to}"
+                    + ". The values reflect the aggregated sentiment from news"
+                    + " sources for each given time period.\n"
+                )
+                descriptive_text = "\n".join(
+                    [
+                        f"On {entry.datetime}, the sentiment value was "
+                        f"{entry.value}."
+                        for entry in time_series_data
+                    ]
+                )
+                return header + descriptive_text
+
+        except Exception as e:
+            return f"Got error: {e}"
+
+    def get_tools(self) -> List[FunctionTool]:
+        r"""Returns a list of FunctionTool objects representing the functions
+        in the toolkit.
+
+        Returns:
+            List[FunctionTool]: A list of FunctionTool objects representing
+                the functions in the toolkit.
+        """
+        return [
+            FunctionTool(self.get_news),
+            FunctionTool(self.get_stories),
+            FunctionTool(self.get_web_search),
+            FunctionTool(self.search_reddit),
+            FunctionTool(self.query_finance),
+        ]
+
+
+ASKNEWS_FUNCS: List[FunctionTool] = AskNewsToolkit().get_tools()
+ASYNC_ASKNEWS_FUNCS: List[FunctionTool] = AsyncAskNewsToolkit().get_tools()
diff --git a/examples/tool_call/ask_news_toolkit_example.py b/examples/tool_call/ask_news_toolkit_example.py
new file mode 100644
index 0000000000..705681533e
--- /dev/null
+++ b/examples/tool_call/ask_news_toolkit_example.py
@@ -0,0 +1,54 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the “License”);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+from camel.toolkits import AskNewsToolkit
+
+ask_news = AskNewsToolkit()
+
+news_output = ask_news.get_news(query="camel-ai")
+print(news_output[:1000])
+
+"""
+===============================================================================
+
+[1]:
+Title: Robot Creates Painting of Alan Turing for $130,000 to $195,000 Auction
+Summary: A robot named 'Ai-Da' created a painting of Alan Turing, a British
+mathematician and father of modern computer science, using artificial
+intelligence technology. The painting is part of an online auction that will
+take place from October 31 to November 7, and is estimated to be worth between
+$130,000 and $195,000. Ai-Da is a highly advanced robot that has arms and a
+face that resembles a human, complete with brown hair. The robot was created
+in 2019 by a team led by Aidan Meller, an art dealer and founder of Ai-Da
+Robot Studio, in collaboration with experts in artificial intelligence from
+the universities of Oxford and Birmingham. Ai-Da uses AI to create paintings
+or sculptures, and has cameras in its eyes and electronic arms. According to
+Ai-Da, 'Through my works on Alan Turing, I celebrate his achievements and
+contributions to the development of computing and artificial intelligence.'
+T
+===============================================================================
+"""
+
+story_output = ask_news.get_stories(
+    query="camel-ai", categories=["Technology"]
+)
+print(story_output)
+
+web_search_output = ask_news.get_web_search(queries=["camel-ai"])
+print(web_search_output)
+
+reddit_output = ask_news.search_reddit(keywords=["camel-ai", "multi-agent"])
+print(reddit_output)
+
+finance_output = ask_news.query_finance(asset="bitcoin")
+print(finance_output)
diff --git a/poetry.lock b/poetry.lock
index 7f067dcd15..1f72a4ffb6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -383,6 +383,40 @@ pdfminer-six = "*"
 PyPDF2 = "*"
 scikit-learn = "*"
 
+[[package]]
+name = "asgiref"
+version = "3.8.1"
+description = "ASGI specs, helper code, and adapters"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"},
+    {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"},
+]
+
+[package.dependencies]
+typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""}
+
+[package.extras]
+tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"]
+
+[[package]]
+name = "asknews"
+version = "0.7.47"
+description = "Python SDK for AskNews"
+optional = true
+python-versions = "<4.0,>=3.8.1"
+files = [
+    {file = "asknews-0.7.47-py3-none-any.whl", hash = "sha256:f11569bd7f488a9097e20996d5bb4f7f240a684fd928d59db4fe4459903b5bbb"},
+]
+
+[package.dependencies]
+asgiref = ">=3.7.2,<4.0.0"
+cryptography = ">=40.0.0,<42.0.7"
+httpx = ">=0.27.2,<0.28.0"
+orjson = ">=3.9.10,<4.0.0"
+pydantic = ">=2.5.3,<3.0.0"
+
 [[package]]
 name = "asttokens"
 version = "2.4.1"
@@ -1111,38 +1145,43 @@ toml = ["tomli"]
 
 [[package]]
 name = "cryptography"
-version = "43.0.1"
+version = "42.0.6"
 description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = true python-versions = ">=3.7" files = [ - {file = "cryptography-43.0.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d"}, - {file = "cryptography-43.0.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062"}, - {file = "cryptography-43.0.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962"}, - {file = "cryptography-43.0.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277"}, - {file = "cryptography-43.0.1-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a"}, - {file = "cryptography-43.0.1-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042"}, - {file = "cryptography-43.0.1-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494"}, - {file = "cryptography-43.0.1-cp37-abi3-win32.whl", hash = "sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2"}, - {file = "cryptography-43.0.1-cp37-abi3-win_amd64.whl", hash = "sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d"}, - {file = "cryptography-43.0.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d"}, - {file = "cryptography-43.0.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806"}, - {file = "cryptography-43.0.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85"}, - {file = "cryptography-43.0.1-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c"}, - {file = "cryptography-43.0.1-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1"}, - {file = "cryptography-43.0.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa"}, - {file = "cryptography-43.0.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4"}, - {file = "cryptography-43.0.1-cp39-abi3-win32.whl", hash = "sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47"}, - {file = "cryptography-43.0.1-cp39-abi3-win_amd64.whl", hash = "sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb"}, - {file = "cryptography-43.0.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034"}, - {file = "cryptography-43.0.1-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d"}, - {file = "cryptography-43.0.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289"}, - {file = "cryptography-43.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84"}, - {file = "cryptography-43.0.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365"}, - {file = "cryptography-43.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96"}, - {file = "cryptography-43.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172"}, - {file = "cryptography-43.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2"}, - {file = "cryptography-43.0.1.tar.gz", hash = "sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d"}, + {file = "cryptography-42.0.6-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:073104df012fc815eed976cd7d0a386c8725d0d0947cf9c37f6c36a6c20feb1b"}, + {file = "cryptography-42.0.6-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:5967e3632f42b0c0f9dc2c9da88c79eabdda317860b246d1fbbde4a8bbbc3b44"}, + {file = "cryptography-42.0.6-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99831397fdc6e6e0aa088b060c278c6e635d25c0d4d14bdf045bf81792fda0a"}, + {file = "cryptography-42.0.6-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:089aeb297ff89615934b22c7631448598495ffd775b7d540a55cfee35a677bf4"}, + {file = "cryptography-42.0.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:97eeacae9aa526ddafe68b9202a535f581e21d78f16688a84c8dcc063618e121"}, + {file = "cryptography-42.0.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f4cece02478d73dacd52be57a521d168af64ae03d2a567c0c4eb6f189c3b9d79"}, + {file = "cryptography-42.0.6-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:aeb6f56b004e898df5530fa873e598ec78eb338ba35f6fa1449970800b1d97c2"}, + {file = "cryptography-42.0.6-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:8b90c57b3cd6128e0863b894ce77bd36fcb5f430bf2377bc3678c2f56e232316"}, + {file = "cryptography-42.0.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d16a310c770cc49908c500c2ceb011f2840674101a587d39fa3ea828915b7e83"}, + {file = "cryptography-42.0.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e3442601d276bd9e961d618b799761b4e5d892f938e8a4fe1efbe2752be90455"}, + {file = "cryptography-42.0.6-cp37-abi3-win32.whl", hash = "sha256:00c0faa5b021457848d031ecff041262211cc1e2bce5f6e6e6c8108018f6b44a"}, + {file = "cryptography-42.0.6-cp37-abi3-win_amd64.whl", hash = "sha256:b16b90605c62bcb3aa7755d62cf5e746828cfc3f965a65211849e00c46f8348d"}, + {file = "cryptography-42.0.6-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:eecca86813c6a923cabff284b82ff4d73d9e91241dc176250192c3a9b9902a54"}, + {file = "cryptography-42.0.6-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d93080d2b01b292e7ee4d247bf93ed802b0100f5baa3fa5fd6d374716fa480d4"}, + {file = "cryptography-42.0.6-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff75b88a4d273c06d968ad535e6cb6a039dd32db54fe36f05ed62ac3ef64a44"}, + {file = "cryptography-42.0.6-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c05230d8aaaa6b8ab3ab41394dc06eb3d916131df1c9dcb4c94e8f041f704b74"}, + {file = "cryptography-42.0.6-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9184aff0856261ecb566a3eb26a05dfe13a292c85ce5c59b04e4aa09e5814187"}, + {file = "cryptography-42.0.6-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:4bdb39ecbf05626e4bfa1efd773bb10346af297af14fb3f4c7cb91a1d2f34a46"}, + {file = "cryptography-42.0.6-cp39-abi3-musllinux_1_1_x86_64.whl", hash = 
"sha256:e85f433230add2aa26b66d018e21134000067d210c9c68ef7544ba65fc52e3eb"}, + {file = "cryptography-42.0.6-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:65d529c31bd65d54ce6b926a01e1b66eacf770b7e87c0622516a840e400ec732"}, + {file = "cryptography-42.0.6-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f1e933b238978ccfa77b1fee0a297b3c04983f4cb84ae1c33b0ea4ae08266cc9"}, + {file = "cryptography-42.0.6-cp39-abi3-win32.whl", hash = "sha256:bc954251edcd8a952eeaec8ae989fec7fe48109ab343138d537b7ea5bb41071a"}, + {file = "cryptography-42.0.6-cp39-abi3-win_amd64.whl", hash = "sha256:9f1a3bc2747166b0643b00e0b56cd9b661afc9d5ff963acaac7a9c7b2b1ef638"}, + {file = "cryptography-42.0.6-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:945a43ebf036dd4b43ebfbbd6b0f2db29ad3d39df824fb77476ca5777a9dde33"}, + {file = "cryptography-42.0.6-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:f567a82b7c2b99257cca2a1c902c1b129787278ff67148f188784245c7ed5495"}, + {file = "cryptography-42.0.6-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3b750279f3e7715df6f68050707a0cee7cbe81ba2eeb2f21d081bd205885ffed"}, + {file = "cryptography-42.0.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6981acac509cc9415344cb5bfea8130096ea6ebcc917e75503143a1e9e829160"}, + {file = "cryptography-42.0.6-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:076c92b08dd1ab88108bc84545187e10d3693a9299c593f98c4ea195a0b0ead7"}, + {file = "cryptography-42.0.6-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:81dbe47e28b703bc4711ac74a64ef8b758a0cf056ce81d08e39116ab4bc126fa"}, + {file = "cryptography-42.0.6-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e1f5f15c5ddadf6ee4d1d624a2ae940f14bd74536230b0056ccb28bb6248e42a"}, + {file = "cryptography-42.0.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:43e521f21c2458038d72e8cdfd4d4d9f1d00906a7b6636c4272e35f650d1699b"}, + {file = "cryptography-42.0.6.tar.gz", hash = "sha256:f987a244dfb0333fbd74a691c36000a2569eaf7c7cc2ac838f85f59f0588ddc9"}, ] [package.dependencies] @@ -1155,7 +1194,7 @@ nox = ["nox"] pep8test = ["check-sdist", "click", "mypy", "ruff"] sdist = ["build"] ssh = ["bcrypt (>=3.1.5)"] -test = ["certifi", "cryptography-vectors (==43.0.1)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] [[package]] @@ -7505,13 +7544,13 @@ files = [ [[package]] name = "setuptools" -version = "75.1.0" +version = "75.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = true python-versions = ">=3.8" files = [ - {file = "setuptools-75.1.0-py3-none-any.whl", hash = "sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2"}, - {file = "setuptools-75.1.0.tar.gz", hash = "sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538"}, + {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, + {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, ] [package.extras] @@ -9483,7 +9522,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["PyMuPDF", "accelerate", "agentops", "arxiv", "arxiv2text", "azure-storage-blob", "beautifulsoup4", "botocore", "cohere", "datasets", "diffusers", "discord.py", 
"docker", "docx2txt", "duckduckgo-search", "firecrawl-py", "google-cloud-storage", "google-generativeai", "googlemaps", "imageio", "jupyter_client", "litellm", "mistralai", "nebula3-python", "neo4j", "newspaper3k", "nltk", "openapi-spec-validator", "opencv-python", "pillow", "prance", "praw", "pyTelegramBotAPI", "pydub", "pygithub", "pymilvus", "pyowm", "qdrant-client", "rank-bm25", "redis", "reka-api", "requests_oauthlib", "scholarly", "sentence-transformers", "sentencepiece", "slack-sdk", "soundfile", "textblob", "torch", "transformers", "unstructured", "wikipedia", "wolframalpha"] +all = ["PyMuPDF", "accelerate", "agentops", "arxiv", "arxiv2text", "asknews", "azure-storage-blob", "beautifulsoup4", "botocore", "cohere", "datasets", "diffusers", "discord.py", "docker", "docx2txt", "duckduckgo-search", "firecrawl-py", "google-cloud-storage", "google-generativeai", "googlemaps", "imageio", "jupyter_client", "litellm", "mistralai", "nebula3-python", "neo4j", "newspaper3k", "nltk", "openapi-spec-validator", "opencv-python", "pillow", "prance", "praw", "pyTelegramBotAPI", "pydub", "pygithub", "pymilvus", "pyowm", "qdrant-client", "rank-bm25", "redis", "reka-api", "requests_oauthlib", "scholarly", "sentence-transformers", "sentencepiece", "slack-sdk", "soundfile", "textblob", "torch", "transformers", "unstructured", "wikipedia", "wolframalpha"] encoders = ["sentence-transformers"] graph-storages = ["nebula3-python", "neo4j"] huggingface-agent = ["accelerate", "datasets", "diffusers", "opencv-python", "sentencepiece", "soundfile", "torch", "transformers"] @@ -9494,10 +9533,10 @@ rag = ["cohere", "nebula3-python", "neo4j", "pymilvus", "qdrant-client", "rank-b retrievers = ["cohere", "rank-bm25"] search-tools = ["duckduckgo-search", "wikipedia", "wolframalpha"] test = ["mock", "pytest", "pytest-asyncio"] -tools = ["PyMuPDF", "agentops", "arxiv", "arxiv2text", "beautifulsoup4", "discord.py", "docker", "docx2txt", "duckduckgo-search", "firecrawl-py", "googlemaps", "imageio", "jupyter_client", "newspaper3k", "nltk", "openapi-spec-validator", "pillow", "prance", "praw", "pyTelegramBotAPI", "pydub", "pygithub", "pyowm", "requests_oauthlib", "scholarly", "slack-sdk", "textblob", "unstructured", "wikipedia", "wolframalpha"] +tools = ["PyMuPDF", "agentops", "arxiv", "arxiv2text", "asknews", "beautifulsoup4", "discord.py", "docker", "docx2txt", "duckduckgo-search", "firecrawl-py", "googlemaps", "imageio", "jupyter_client", "newspaper3k", "nltk", "openapi-spec-validator", "pillow", "prance", "praw", "pyTelegramBotAPI", "pydub", "pygithub", "pyowm", "requests_oauthlib", "scholarly", "slack-sdk", "textblob", "unstructured", "wikipedia", "wolframalpha"] vector-databases = ["pymilvus", "qdrant-client"] [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<=3.13" -content-hash = "dd92185c6f132df5561577938bce32b57b798166241971f02b5a6aa418459efe" +content-hash = "110607e2b428e91354db877519ffbaafe35c9b7080b6fa12bbbabe46ffb4f4ae" diff --git a/pyproject.toml b/pyproject.toml index ae3636ed0d..c4f9c513ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ azure-storage-blob = { version = "^12.21.0", optional = true } google-cloud-storage = { version = "^2.18.0", optional = true } botocore = { version = "^1.35.3", optional = true } nltk = { version = "3.8.1", optional = true } +asknews = { version = "^0.7.43", optional = true } praw = { version = "^7.7.1", optional = true } textblob = { version = "^0.18.0.post0", optional = true } scholarly = { extras = ["tor"], version = "1.7.11", optional = true 
} @@ -188,6 +189,7 @@ tools = [ "docker", "jupyter_client", "agentops", + "asknews", "praw", "textblob", "scholarly", @@ -258,6 +260,7 @@ all = [ "agentops", "praw", "textblob", + "asknews", "scholarly", # vector-database "qdrant-client", @@ -418,6 +421,7 @@ module = [ "reka-api", "agentops", "botocore.*", + "asknews", "arxiv", "arxiv2text", "praw", diff --git a/test/toolkits/test_asknews_function.py b/test/toolkits/test_asknews_function.py new file mode 100644 index 0000000000..b1e724d1b4 --- /dev/null +++ b/test/toolkits/test_asknews_function.py @@ -0,0 +1,148 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import os +import unittest +from unittest.mock import MagicMock, patch + +from camel.toolkits.ask_news_toolkit import AskNewsToolkit, _process_response + + +class TestAskNewsToolkit(unittest.TestCase): + @patch.dict( + os.environ, + { + "ASKNEWS_CLIENT_ID": "fake_client_id", + "ASKNEWS_CLIENT_SECRET": "fake_client_secret", + }, + ) + @patch("asknews_sdk.AskNewsSDK") + def setUp(self, MockAskNewsSDK): + # Setup for tests + self.mock_sdk = MockAskNewsSDK.return_value + self.toolkit = AskNewsToolkit() + + def test_get_news_success(self): + # Mock the API response for a successful get_news call + mock_response = MagicMock() + mock_response.as_string = "News in string format" + self.mock_sdk.news.search_news.return_value = mock_response + + result = self.toolkit.get_news( + query="test query", return_type="string" + ) + + self.assertEqual(result, "News in string format") + self.mock_sdk.news.search_news.assert_called_once_with( + query="test query", + n_articles=10, + return_type="string", + method="kw", + ) + + def test_get_news_failure(self): + # Test handling of an exception in get_news + self.mock_sdk.news.search_news.side_effect = Exception("API Error") + result = self.toolkit.get_news(query="test query") + self.assertEqual(result, "Got error: API Error") + + def test_search_reddit_success(self): + # Mock the API response for search_reddit + mock_response = MagicMock() + mock_response.as_string = "Reddit threads in string format" + self.mock_sdk.news.search_reddit.return_value = mock_response + + result = self.toolkit.search_reddit( + keywords=["test"], n_threads=5, return_type="string" + ) + + self.assertEqual(result, "Reddit threads in string format") + self.mock_sdk.news.search_reddit.assert_called_once_with( + keywords=["test"], n_threads=5, method="kw" + ) + + def test_get_stories_success(self): + # Mock the API response for get_stories + mock_story = MagicMock() + mock_story.updates = [ + MagicMock(headline="Update 1 headline", story="Update 1 story"), + MagicMock(headline="Update 2 headline", story="Update 2 story"), + ] + mock_response = MagicMock() + mock_response.stories = [mock_story] + self.mock_sdk.stories.search_stories.return_value = mock_response + + result = self.toolkit.get_stories( + query="test query", 
categories=["Sports"] + ) + + expected_result = { + "stories": [ + { + "headline": "Update 1 headline", + "updates": [ + { + "headline": "Update 1 headline", + "story": "Update 1 story", + }, + { + "headline": "Update 2 headline", + "story": "Update 2 story", + }, + ], + } + ] + } + self.assertEqual(result, expected_result) + self.mock_sdk.stories.search_stories.assert_called_once_with( + query="test query", + categories=["Sports"], + reddit=3, + expand_updates=True, + max_updates=2, + max_articles=10, + ) + + def test_get_stories_failure(self): + # Test handling of an exception in get_stories + self.mock_sdk.stories.search_stories.side_effect = Exception( + "API Error" + ) + result = self.toolkit.get_stories( + query="test query", categories=["Sports"] + ) + self.assertEqual(result, "Got error: API Error") + + def test_process_response(self): + # Test _process_response utility function + mock_response = MagicMock() + mock_response.as_string = "response in string" + mock_response.as_dicts = {"response": "in dict"} + + # Test for string return type + result = _process_response(mock_response, "string") + self.assertEqual(result, "response in string") + + # Test for dicts return type + result = _process_response(mock_response, "dicts") + self.assertEqual(result, {"response": "in dict"}) + + # Test for both return type + result = _process_response(mock_response, "both") + self.assertEqual( + result, ("response in string", {"response": "in dict"}) + ) + + # Test for invalid return type + with self.assertRaises(ValueError): + _process_response(mock_response, "invalid_type")