From 11283655fe985655a004065399219ac0609487a7 Mon Sep 17 00:00:00 2001 From: namastex888 <105755034+namastex888@users.noreply.github.com> Date: Thu, 3 Oct 2024 12:33:28 -0300 Subject: [PATCH] feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods (#3738) * feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods * test: fix test * test: fix test units * test: fix import * fix: rename component * Fix instantiation of YfinanceToolComponent in complex_agent.py --------- Co-authored-by: italojohnny Co-authored-by: Gabriel Luiz Freitas Almeida --- .../langflow/components/tools/YfinanceTool.py | 108 ++++++++++++++---- .../components/tools/test_yfinance_tool.py | 2 +- 2 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/backend/base/langflow/components/tools/YfinanceTool.py b/src/backend/base/langflow/components/tools/YfinanceTool.py index 3ffd2062820..b32850e9862 100644 --- a/src/backend/base/langflow/components/tools/YfinanceTool.py +++ b/src/backend/base/langflow/components/tools/YfinanceTool.py @@ -1,35 +1,101 @@ -from typing import cast +import ast +import pprint -from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool +import yfinance as yf +from langchain.tools import StructuredTool +from pydantic import BaseModel, Field from langflow.base.langchain_utilities.model import LCToolComponent -from langflow.field_typing import Data, Tool -from langflow.inputs.inputs import MessageTextInput -from langflow.template.field.base import Output +from langflow.field_typing import Tool +from langflow.inputs import DropdownInput, IntInput, MessageTextInput +from langflow.schema import Data class YfinanceToolComponent(LCToolComponent): - display_name = "Yahoo Finance News Tool" - description = "Tool for interacting with Yahoo Finance News." - name = "YFinanceTool" + display_name = "Yahoo Finance Tool" + description = "Access financial data and market information using Yahoo Finance." + icon = "trending-up" + name = "YahooFinanceTool" inputs = [ MessageTextInput( - name="input_value", - display_name="Query", - info="Input should be a company ticker. For example, AAPL for Apple, MSFT for Microsoft.", - ) + name="symbol", + display_name="Stock Symbol", + info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).", + required=True, + ), + DropdownInput( + name="method", + display_name="Data Method", + info="The type of data to retrieve.", + options=[ + "get_actions", + "get_analysis", + "get_balance_sheet", + "get_calendar", + "get_cashflow", + "get_info", + "get_institutional_holders", + "get_news", + "get_recommendations", + "get_sustainability", + ], + value="get_news", + ), + IntInput( + name="num_news", + display_name="Number of News", + info="The number of news articles to retrieve (only applicable for get_news).", + value=5, + ), ] - outputs = [ - Output(name="api_run_model", display_name="Data", method="run_model"), - # Keep this for backwards compatibility - Output(name="tool", display_name="Tool", method="build_tool"), - ] + class YahooFinanceSchema(BaseModel): + symbol: str = Field(..., description="The stock symbol to retrieve data for.") + method: str = Field("get_info", description="The type of data to retrieve.") + num_news: int | None = Field(5, description="The number of news articles to retrieve.") + + def run_model(self) -> list[Data]: + return self._yahoo_finance_tool( + self.symbol, + self.method, + self.num_news, + ) def build_tool(self) -> Tool: - return cast(Tool, YahooFinanceNewsTool()) + return StructuredTool.from_function( + name="yahoo_finance", + description="Access financial data and market information from Yahoo Finance.", + func=self._yahoo_finance_tool, + args_schema=self.YahooFinanceSchema, + ) + + def _yahoo_finance_tool( + self, + symbol: str, + method: str, + num_news: int | None = 5, + ) -> list[Data]: + ticker = yf.Ticker(symbol) + + try: + if method == "get_info": + result = ticker.info + elif method == "get_news": + result = ticker.news[:num_news] + else: + result = getattr(ticker, method)() + + result = pprint.pformat(result) + + if method == "get_news": + data_list = [Data(data=article) for article in ast.literal_eval(result)] + else: + data_list = [Data(data={"result": result})] + + return data_list - def run_model(self) -> Data: - tool = self.build_tool() - return tool.run(self.input_value) + except Exception as e: + error_message = f"Error retrieving data: {str(e)}" + self.status = error_message + return [Data(data={"error": error_message})] diff --git a/src/backend/tests/unit/components/tools/test_yfinance_tool.py b/src/backend/tests/unit/components/tools/test_yfinance_tool.py index 7b19b72da1f..75630a42dae 100644 --- a/src/backend/tests/unit/components/tools/test_yfinance_tool.py +++ b/src/backend/tests/unit/components/tools/test_yfinance_tool.py @@ -17,5 +17,5 @@ def test_yfinance_tool_template(): assert "outputs" in frontend_node output_names = [output["name"] for output in frontend_node["outputs"]] assert "api_run_model" in output_names - assert "tool" in output_names + assert "api_build_tool" in output_names assert all(output["types"] != [] for output in frontend_node["outputs"])