-
Notifications
You must be signed in to change notification settings - Fork 261
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: finbert integration with uagents (#151)
- Loading branch information
1 parent
f505b8d
commit b70c04c
Showing
10 changed files
with
1,435 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Finbert Integration | ||
|
||
FinBERT is a pre-trained NLP model to analyze sentiment of financial text. It is built by further training the BERT language model in the finance domain, using a large financial corpus and thereby fine-tuning it for financial sentiment classification. | ||
|
||
Ensure you have the following software installed: | ||
|
||
- Python (v3.10+ recommended) | ||
- Poetry (A Python packaging and dependency management tool) | ||
|
||
## Setup | ||
|
||
1. For the demo to work, you need to get HuggingFaceAPI Token: | ||
|
||
1. Visit [HuggingFace](https://huggingface.co/). | ||
2. Sign up or log in. | ||
3. Navigate to `Profile -> Settings -> Access Tokens`. | ||
4. Copy an existing token or create a new one. | ||
|
||
2. In the `finbert/src` directory, create a `.env` file and set your HuggingFaceAPI Token: | ||
|
||
``` | ||
export HUGGING_FACE_ACCESS_TOKEN="{Your HuggingFaceAPI Token}" | ||
``` | ||
3. To load the environment variables from `.env` and install the dependencies: | ||
```bash | ||
cd src | ||
source .env | ||
poetry install | ||
``` | ||
## Running The Script | ||
To run the project, use the command: | ||
``` | ||
poetry run python main.py | ||
``` | ||
In the console look for the following output: | ||
``` | ||
Adding agent to Bureau: {agent_address} | ||
|
||
``` | ||
Copy the {agent_address} address and replace **AI_MODEL_AGENT_ADDRESS** value in agents/finbert_user.py. | ||
You can change the input text to what you like, just open the src/finbert_user.py file, and change the value of **INPUT_TEXT** variable in agents/finbert_user.py. | ||
The finbert model will give softmax outputs for three labels: positive, negative or neutral. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[tool.poetry] | ||
name = "finbert-integration" | ||
version = "0.0.1" | ||
description = "showcasing finbert integration with uagents" | ||
authors = ["Sangram Singh"] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.10,<3.12" | ||
uagents = "*" | ||
requests = "^2.31.0" |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Necessary imports: uagents for agent creation and message handling, | ||
# os and requests for managing API calls | ||
from uagents import Agent, Context, Protocol | ||
from messages.basic import UAResponse, UARequest, Error | ||
from uagents.setup import fund_agent_if_low | ||
import os | ||
import requests | ||
|
||
# The access token and URL for the FinBERT model, served by Hugging Face | ||
HUGGING_FACE_ACCESS_TOKEN = os.getenv( | ||
"HUGGING_FACE_ACCESS_TOKEN", "HUGGING FACE secret phrase :)") | ||
FINBERT_URL = "https://api-inference.huggingface.co/models/ProsusAI/finbert" | ||
|
||
# Setting the headers for the API call | ||
HEADERS = { | ||
"Authorization": f"Bearer {HUGGING_FACE_ACCESS_TOKEN}" | ||
} | ||
|
||
# Creating the agent and funding it if necessary | ||
agent = Agent( | ||
name="finbert_agent", | ||
seed=HUGGING_FACE_ACCESS_TOKEN, | ||
port=8001, | ||
endpoint=["http://127.0.0.1:8001/submit"], | ||
) | ||
fund_agent_if_low(agent.wallet.address()) | ||
|
||
# Function to get classification results from FinBERT for a given text | ||
|
||
|
||
async def get_classification(ctx: Context, sender: str, text: str): | ||
data = { | ||
"inputs": text | ||
} | ||
|
||
try: | ||
# Making POST request to Hugging Face FinBERT API | ||
response = requests.post(FINBERT_URL, headers=HEADERS, json=data) | ||
|
||
if response.status_code != 200: | ||
# Error handling - send error message back to user if API call unsuccessful | ||
await ctx.send(sender, Error(error=f"Error: {response.json().get('error')}")) | ||
return | ||
# If API call is successful, return the response from the model | ||
model_res = response.json()[0] | ||
await ctx.send(sender, UAResponse(response=model_res)) | ||
return | ||
except Exception as ex: | ||
# Catch and notify any exception occured during API call or data handling | ||
await ctx.send(sender, Error(error=f"An exception occurred while processing the request: {ex}")) | ||
return | ||
|
||
|
||
# Protocol declaration for UARequests | ||
finbert_agent = Protocol("UARequest") | ||
|
||
# Declaration of a message event handler for handling UARequests and send respective response. | ||
|
||
|
||
@finbert_agent.on_message(model=UARequest, replies={UAResponse, Error}) | ||
async def handle_request(ctx: Context, sender: str, request: UARequest): | ||
# Logging the request information | ||
ctx.logger.info( | ||
f"Got request from {sender} for text classification : {request.text}") | ||
|
||
# Call text classification function for the incoming request's text | ||
await get_classification(ctx, sender, request.text) | ||
|
||
# Include protocol to the agent | ||
agent.include(finbert_agent) | ||
|
||
# If the script is run as the main program, run our agents event loop | ||
if __name__ == "__main__": | ||
finbert_agent.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from uagents import Agent, Context, Protocol # Import necessary modules | ||
from messages.basic import UAResponse, UARequest, Error # Import Basic messages | ||
# Import the fund_agent_if_low function | ||
from uagents.setup import fund_agent_if_low | ||
import os # Import os for environment variables if any needed | ||
|
||
INPUT_TEXT = "Stocks rallied and the British pound gained." | ||
AI_MODEL_AGENT_ADDRESS = "agent1q2gdm20yg3krr0k5nt8kf794ftm46x2cs7xyrm2gqc9ga5lulz9xsp4l4kk" | ||
|
||
user = Agent( | ||
name="finbert_user", | ||
port=8000, | ||
endpoint=["http://127.0.0.1:8000/submit"], | ||
) | ||
|
||
fund_agent_if_low(user.wallet.address()) | ||
|
||
finbert_user = Protocol("Request") | ||
|
||
|
||
@finbert_user.on_interval(360, messages=UARequest) | ||
async def text_classification(ctx: Context): | ||
ctx.logger.info(f"Asking AI model agent to classify: {INPUT_TEXT}") | ||
await ctx.send(AI_MODEL_AGENT_ADDRESS, UARequest(text=INPUT_TEXT)) | ||
|
||
|
||
@finbert_user.on_message(model=UAResponse) | ||
async def handle_data(ctx: Context, sender: str, data: UAResponse): | ||
ctx.logger.info(f"Got response from AI model agent: {data.response}") | ||
|
||
|
||
@finbert_user.on_message(model=Error) | ||
async def handle_error(ctx: Context, sender: str, error: Error): | ||
ctx.logger.info(f"Got error from AI model agent: {error}") | ||
|
||
user.include(finbert_user) | ||
|
||
if __name__ == "__main__": | ||
finbert_user.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from uagents import Bureau | ||
|
||
from agents.finbert_agent import agent | ||
from agents.finbert_user import user | ||
|
||
|
||
if __name__ == "__main__": | ||
bureau = Bureau(endpoint="http://127.0.0.1:8000/submit", port=8000) | ||
print(f"Adding agent to Bureau: {agent.address}") | ||
bureau.add(agent) | ||
print(f"Adding user agent to Bureau: {user.address}") | ||
bureau.add(user) | ||
bureau.run() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from uagents import Model | ||
|
||
|
||
class UARequest(Model): | ||
text: str | ||
|
||
|
||
class Error(Model): | ||
error: str | ||
|
||
|
||
class UAResponse(Model): | ||
response: list |