Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods #3738

Merged
merged 6 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 87 additions & 21 deletions src/backend/base/langflow/components/tools/YfinanceTool.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,101 @@
from typing import cast
import ast
import pprint

from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool
import yfinance as yf
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field

from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.field_typing import Data, Tool
from langflow.inputs.inputs import MessageTextInput
from langflow.template.field.base import Output
from langflow.field_typing import Tool
from langflow.inputs import DropdownInput, IntInput, MessageTextInput
from langflow.schema import Data


class YfinanceToolComponent(LCToolComponent):
display_name = "Yahoo Finance News Tool"
description = "Tool for interacting with Yahoo Finance News."
name = "YFinanceTool"
display_name = "Yahoo Finance Tool"
description = "Access financial data and market information using Yahoo Finance."
icon = "trending-up"
name = "YahooFinanceTool"

inputs = [
MessageTextInput(
name="input_value",
display_name="Query",
info="Input should be a company ticker. For example, AAPL for Apple, MSFT for Microsoft.",
)
name="symbol",
display_name="Stock Symbol",
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
required=True,
),
DropdownInput(
name="method",
display_name="Data Method",
info="The type of data to retrieve.",
options=[
"get_actions",
"get_analysis",
"get_balance_sheet",
"get_calendar",
"get_cashflow",
"get_info",
"get_institutional_holders",
"get_news",
"get_recommendations",
"get_sustainability",
],
value="get_news",
),
IntInput(
name="num_news",
display_name="Number of News",
info="The number of news articles to retrieve (only applicable for get_news).",
value=5,
),
]

outputs = [
Output(name="api_run_model", display_name="Data", method="run_model"),
# Keep this for backwards compatibility
Output(name="tool", display_name="Tool", method="build_tool"),
]
class YahooFinanceSchema(BaseModel):
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
method: str = Field("get_info", description="The type of data to retrieve.")
num_news: int | None = Field(5, description="The number of news articles to retrieve.")

def run_model(self) -> list[Data]:
return self._yahoo_finance_tool(
self.symbol,
self.method,
self.num_news,
)

def build_tool(self) -> Tool:
return cast(Tool, YahooFinanceNewsTool())
return StructuredTool.from_function(
name="yahoo_finance",
description="Access financial data and market information from Yahoo Finance.",
func=self._yahoo_finance_tool,
args_schema=self.YahooFinanceSchema,
)

def _yahoo_finance_tool(
self,
symbol: str,
method: str,
num_news: int | None = 5,
) -> list[Data]:
ticker = yf.Ticker(symbol)

try:
if method == "get_info":
result = ticker.info
elif method == "get_news":
result = ticker.news[:num_news]
else:
result = getattr(ticker, method)()

result = pprint.pformat(result)

if method == "get_news":
data_list = [Data(data=article) for article in ast.literal_eval(result)]
else:
data_list = [Data(data={"result": result})]

return data_list

def run_model(self) -> Data:
tool = self.build_tool()
return tool.run(self.input_value)
except Exception as e:
error_message = f"Error retrieving data: {str(e)}"
self.status = error_message
return [Data(data={"error": error_message})]
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def test_yfinance_tool_template():
assert "outputs" in frontend_node
output_names = [output["name"] for output in frontend_node["outputs"]]
assert "api_run_model" in output_names
assert "tool" in output_names
assert "api_build_tool" in output_names
assert all(output["types"] != [] for output in frontend_node["outputs"])
Loading