-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdermassist_streamlit.py
115 lines (96 loc) · 4.01 KB
/
dermassist_streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import hashlib
import os

import streamlit as st
from PIL import Image
from langchain_core.messages import HumanMessage, AIMessage

from rag_system import RAG
from vision_model import VisionModel
class DermAssist:
    """Streamlit front end for DermAssist.

    The user uploads a photo of the affected skin area. The image is saved to
    disk and classified by ``VisionModel``. The prediction seeds a chat whose
    follow-up questions are answered by the ``RAG`` system.
    """

    def __init__(self, image_save_dir, dermassist_logo, skin_disease=None):
        """Create the app and its models.

        Args:
            image_save_dir: Directory where uploaded images are written.
            dermassist_logo: Path to the logo image shown in the page header.
            skin_disease: Optional initial diagnosis; ``run`` overwrites it
                with the vision model's prediction.
        """
        self.dermassist_logo = dermassist_logo
        self.uploaded_file = None
        self.image = None
        self.image_save_dir = image_save_dir
        self.image_save_path = None
        # NOTE(review): Streamlit re-executes the script on each interaction, so
        # these models are presumably rebuilt per rerun unless RAG/VisionModel
        # cache internally -- consider st.cache_resource; confirm.
        self.rag = RAG()
        self.vision_model = VisionModel()
        self.skin_disease = skin_disease

    def create_directory(self):
        """Create ``image_save_dir`` (including parents) if it does not exist."""
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.image_save_dir, exist_ok=True)

    def save_image(self):
        """Write the uploaded file's raw bytes into ``image_save_dir``.

        Sets ``image_save_path`` to the written file. Does nothing when no
        image has been loaded.
        """
        if self.image is None:
            return
        self.create_directory()
        # The file name is supplied by the client; keep only its final path
        # component so names such as "../../x.png" cannot escape image_save_dir.
        file_name = os.path.basename(self.uploaded_file.name)
        self.image_save_path = os.path.join(self.image_save_dir, file_name)
        with open(self.image_save_path, "wb") as f:
            f.write(self.uploaded_file.getbuffer())

    def display_image(self):
        """Show the uploaded image in the middle of a 1:2:1 column layout."""
        _, center, _ = st.columns([1, 2, 1])
        with center:
            # NOTE(review): newer Streamlit releases deprecate use_column_width;
            # migrate once the pinned Streamlit version supports the replacement.
            st.image(self.image, caption="Uploaded Image", use_column_width=True)

    def upload_image(self):
        """Render the uploader, then open and save the chosen file.

        Returns:
            True if a file was uploaded and opened as an image, else False.
        """
        self.uploaded_file = st.file_uploader(
            "Upload an image of the affected area", type=["jpg", "jpeg", "png"]
        )
        if self.uploaded_file is None:
            return False
        try:
            self.image = Image.open(self.uploaded_file)
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) when the
            # extension is allowed but the content is not a readable image.
            self.image = None
            st.error("The uploaded file could not be read as an image.")
            return False
        self.save_image()
        return True

    @staticmethod
    def initialize_chat_history(skin_disease):
        """Seed the chat with the diagnosis; reset it when the diagnosis changes.

        Args:
            skin_disease: Predicted condition name(s), either an iterable of
                strings or a single string.
        """
        # A bare string would otherwise be joined character by character.
        if isinstance(skin_disease, str):
            skin_disease = [skin_disease]
        skin_diseases = " and ".join(skin_disease)
        initial_context = f"You are suffering with {skin_diseases}"
        # Start a fresh conversation when none exists, or when a new upload
        # produced a different diagnosis. Otherwise the stale context would be
        # shown and passed to the RAG system.
        if ("chat_history" not in st.session_state
                or st.session_state.get("chat_context") != initial_context):
            st.session_state.chat_history = [AIMessage(content=initial_context)]
            st.session_state.chat_context = initial_context

    @staticmethod
    def display_chat():
        """Render every message stored in the session's chat history."""
        for message in st.session_state.chat_history:
            role = "Human" if isinstance(message, HumanMessage) else "AI"
            with st.chat_message(role):
                st.markdown(message.content)

    def handle_user_input(self):
        """Read a question, stream the RAG answer, and record both in history."""
        user_input = st.chat_input("Ask a question")
        if not user_input:
            return
        st.session_state.chat_history.append(HumanMessage(content=user_input))
        with st.chat_message("Human"):
            st.markdown(user_input)
        with st.chat_message("AI"):
            # The history passed here already contains the question just appended.
            rag_response = st.write_stream(
                self.rag.generate_response_streamlit(user_input, st.session_state.chat_history))
        st.session_state.chat_history.append(AIMessage(content=rag_response))

    def setup_page(self):
        """Configure the page and render the centred logo and title banner."""
        # Keep this the first Streamlit call; older Streamlit versions require it.
        st.set_page_config(page_title="DermAssist", page_icon="⚕️")
        _, center, _ = st.columns(3)
        logo = Image.open(self.dermassist_logo)
        with center:
            st.image(logo, width=200)
        st.markdown(
            """
<h1 style="text-align: center;">
<span style="color: white;">Derm</span>
<span style="color: red;">Assist</span>
</h1>
<h3 style="text-align: center;">
Your AI Assistant for Skin Problems
</h3>
<br>
""",
            unsafe_allow_html=True
        )

    def run(self):
        """Render the page, diagnose the uploaded image, then run the chat.

        Calls ``st.stop()`` (ending this script run) when no usable image
        has been uploaded.
        """
        self.setup_page()
        if not self.upload_image():
            st.info("Please upload an image of the affected area to perform diagnosis and ask questions.")
            st.stop()
        self.display_image()
        self.skin_disease = self._predict_skin_disease()
        self.initialize_chat_history(skin_disease=self.skin_disease)
        self.display_chat()
        self.handle_user_input()

    def _predict_skin_disease(self):
        """Return the vision model's prediction, computed once per distinct image.

        Every chat message triggers a full script rerun. Caching the result in
        session state, keyed by a hash of the uploaded bytes, avoids repeating
        inference for an image that was already diagnosed.
        """
        image_key = hashlib.sha256(self.uploaded_file.getbuffer()).hexdigest()
        if st.session_state.get("diagnosed_image_key") != image_key:
            # Store the result before the key so a failed predict is retried.
            st.session_state.diagnosis = self.vision_model.predict(self.image_save_path)
            st.session_state.diagnosed_image_key = image_key
        return st.session_state.diagnosis
if __name__ == '__main__':
    # Script entry point: build the app with its image directory and logo path,
    # then render it.
    app = DermAssist(
        image_save_dir="./images",
        dermassist_logo="./media/derm_assist_logo.png",
    )
    app.run()