Skip to content
This repository has been archived by the owner on Mar 26, 2024. It is now read-only.

Commit

Permalink
feat: added toxicity detection for the AI bot (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
izam-mohammed committed Mar 9, 2024
2 parents 7f4a41f + 7a691e4 commit c85e29d
Show file tree
Hide file tree
Showing 14 changed files with 145 additions and 19 deletions.
1 change: 1 addition & 0 deletions ai_server/ai_server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"rest_framework",
"corsheaders",
"chat",
"toxicity",
]

MIDDLEWARE = [
Expand Down
3 changes: 2 additions & 1 deletion ai_server/ai_server/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

urlpatterns = [
# path("admin/", admin.site.urls),
path("", include("chat.urls"))
path("chat", include("chat.urls")),
path("toxic", include("toxicity.urls")),
]

handler404 = "utils.exception_handler.error_404"
Expand Down
33 changes: 24 additions & 9 deletions ai_server/chat/llm_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,41 @@
from langchain_core.output_parsers import StrOutputParser
from .prompt import PROMPT

def call(input: str) -> Dict[str, Union[bool, str]]:
if "OPENAI_API_KEY" not in os.environ:
def call(user_input: str) -> Dict[str, Union[bool, str]]:
"""
Calls the OpenAI Chat API to generate a response.
Args:
user_input (str): The input text to generate a response for.
Returns:
Dict[str, Union[bool, str]]: A dictionary containing the success status and the generated response.
"""
if "OPENAI_API_KEY" not in os.environ or not os.environ["OPENAI_API_KEY"]:
return {
"success": False,
"message": "OpenAI API key not found",
}
"message": "OpenAI API key not found or empty",
}

prompt = ChatPromptTemplate.from_template(PROMPT)
model = ChatOpenAI(temperature=0)
chain = prompt | model | StrOutputParser()

response:str = chain.invoke({"prompt": input})
chain = prompt | model | StrOutputParser()

try:
response: str = chain.invoke({"prompt": user_input})
except Exception as e:
return {
"success": False,
"message": f"Error in generation: {str(e)}",
}

if not response:
return {
"success": False,
"message": "Error in generation",
}
}

return {
"success": True,
"message": response,
}
}
3 changes: 1 addition & 2 deletions ai_server/chat/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@
from . import views

urlpatterns = [
path("", views.home, name="home"),
path('chat', views.index, name="routes"),
path("", views.index, name="chat"),
]
7 changes: 0 additions & 7 deletions ai_server/chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@
from rest_framework import status
from .llm_call import call as llm_call

@api_view(["GET", "POST"])
def home(request):
return Response(
{"success":True ,"message":"connected successfully"},
status=status.HTTP_200_OK
)

@api_view(["POST"])
def index(request):
content: dict = request.data
Expand Down
Empty file added ai_server/toxicity/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions ai_server/toxicity/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from django.contrib import admin

# Register your models here.
6 changes: 6 additions & 0 deletions ai_server/toxicity/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class ToxicityConfig(AppConfig):
    """Django app configuration for the ``toxicity`` app (toxicity classification endpoint)."""

    # Use 64-bit auto-incrementing primary keys for models in this app.
    default_auto_field = "django.db.models.BigAutoField"
    name = "toxicity"
57 changes: 57 additions & 0 deletions ai_server/toxicity/hf_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Dict, Union, List
import requests
import os

# A single classification entry, e.g. {"label": "toxic", "score": 0.97},
# optionally extended with a boolean "success" flag before returning.
RESPONSE_TYPE = Dict[str, Union[str, float]]


def _transform_response(response: List[List[RESPONSE_TYPE]]) -> RESPONSE_TYPE:
    """
    Transforms the response from the API call.

    Picks the highest-scoring label entry. Generalized from a two-entry
    comparison so responses with any number of labels are handled.

    Args:
        response (List[List[RESPONSE_TYPE]]): The raw response from the API call.

    Returns:
        RESPONSE_TYPE: The entry with the highest ``score``.
    """
    return max(response[0], key=lambda entry: entry["score"])


def call(text: str) -> RESPONSE_TYPE:
    """
    Calls the Hugging Face API to classify toxic comments.

    Args:
        text (str): The text to classify.

    Returns:
        RESPONSE_TYPE: On success, the highest-scoring classification entry
        with ``success`` set to True; otherwise a dict with ``success`` False
        and a human-readable error ``message``.
    """
    if "HF_TOKEN" not in os.environ or not os.environ["HF_TOKEN"]:
        return {
            "success": False,
            "message": "Huggingface Token not found or empty",
        }

    hf_token = os.environ["HF_TOKEN"]
    headers = {"Authorization": f"Bearer {hf_token}"}
    api_url = (
        "https://api-inference.huggingface.co/models/martin-ha/toxic-comment-model"
    )

    try:
        # Bounded timeout so a stalled upstream call cannot hang the worker.
        response = requests.post(
            api_url, headers=headers, json={"inputs": text}, timeout=30
        )
    except requests.exceptions.RequestException as e:
        return {
            "success": False,
            "message": f"Request to Huggingface failed: {str(e)}",
        }

    if response.status_code != 200:
        return {
            "success": False,
            "message": f"API returned status code {response.status_code}",
        }

    try:
        response_data = response.json()
        transformed_response = _transform_response(response_data)
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # Non-JSON body or unexpected payload shape (e.g. the inference API
        # returns an error dict while the model is still loading).
        return {
            "success": False,
            "message": f"Unexpected API response: {str(e)}",
        }

    transformed_response["success"] = True

    return transformed_response
Empty file.
3 changes: 3 additions & 0 deletions ai_server/toxicity/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from django.db import models

# Create your models here.
3 changes: 3 additions & 0 deletions ai_server/toxicity/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from django.test import TestCase

# Create your tests here.
6 changes: 6 additions & 0 deletions ai_server/toxicity/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.urls import path
from . import views

urlpatterns = [
    # Root of the app: POST-only toxicity classification endpoint.
    path("", views.index, name="toxic"),
]
39 changes: 39 additions & 0 deletions ai_server/toxicity/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Dict, Union
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework import status
from .hf_call import call as hf_call

@api_view(["POST"])
def index(request):
    """
    Classify the toxicity of a piece of text.

    Expects a JSON body with a string ``text`` field and delegates the
    classification to the Hugging Face inference API via ``hf_call``.

    Returns:
        Response: 200 with the classification payload on success,
        400 when ``text`` is missing, not a string, or empty,
        500 when the upstream call reports a failure.
    """
    content: dict = request.data
    text = content.get("text")
    if text is None:
        return Response(
            {"success": False, "message": "text is required !"},
            status=status.HTTP_400_BAD_REQUEST,
        )
    if not isinstance(text, str):
        return Response(
            {"success": False, "message": "text should be str !"},
            status=status.HTTP_400_BAD_REQUEST,
        )
    if not text:
        # Message names the actual field ("text"); the previous wording
        # said "query", copied from the chat view.
        return Response(
            {"success": False, "message": "text should not be empty"},
            status=status.HTTP_400_BAD_REQUEST,
        )

    response: Dict[str, Union[bool, str]] = hf_call(text=text)

    if not response["success"]:
        return Response(
            response,
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    return Response(
        response,
        status=status.HTTP_200_OK,
    )

0 comments on commit c85e29d

Please sign in to comment.