
Commit 7ea31ef

Anna050689 (Hanna Imshenetska) and Hanna Imshenetska authored
refactor the code related to Streamlit UI (#385)
* move the frontend part of the code to 'streamlit_app/css/style.css'
* refactor the code related to the Streamlit UI
* fix issues raised by 'flake8'
* update 'VERSION'
* add a hint with the version of the syngen library to the title of the Streamlit UI page
* minor changes in 'streamlit_app/run.py'
* add changes to support fetching the 'syngen' version
* minor changes in 'src/setup.py', 'src/syngen/VERSION'

---------

Co-authored-by: Hanna Imshenetska <[email protected]@EVZZAMZSA0021.epam.com>
1 parent 8a8bc58 commit 7ea31ef
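
The 'streamlit_app/run.py' change that puts the library version into the page title is not included in the hunks below. As a rough sketch only, such a title could be built from the __version__ attribute introduced in this commit; the title wording and the use of st.set_page_config are assumptions, not the actual run.py code.

# Hypothetical sketch: show the syngen version in the Streamlit page title.
# Only syngen.__version__ comes from this commit; the title text is assumed.
import streamlit as st
from syngen import __version__

st.set_page_config(page_title=f"Syngen {__version__}")
st.title(f"Syngen {__version__}")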

File tree

13 files changed, +396 -330 lines changed


setup.cfg

+1
@@ -78,3 +78,4 @@ where = src
 console_scripts =
     train = syngen.train:launch_train
     infer = syngen.infer:launch_infer
+    syngen = syngen:main
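
The added entry point maps a 'syngen' console command to the main() function introduced in src/syngen/__init__.py further down in this diff. As a rough illustration (assuming the package is installed so the script exists), invoking the command behaves like the following:

# Rough equivalent of running `syngen --version` from a shell; argparse's
# "version" action prints "syngen <version>" and exits.
import sys
from syngen import main

sys.argv = ["syngen", "--version"]
main()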

src/setup.py

+4-3
@@ -1,10 +1,11 @@
 from setuptools import setup, find_packages
 
-with open("../src/syngen/VERSION", "r") as file:
-    version_info = file.read()
+from syngen import __version__
+
+
 setup(
     name="syngen",
-    version=version_info,
+    version=__version__,
     packages=find_packages(),
     include_package_data=True
 )

src/syngen/VERSION

+1-1
@@ -1 +1 @@
-0.7.19
+0.7.20

src/syngen/__init__.py

+21
@@ -1 +1,22 @@
+import os
+import argparse
+
 from syngen.train import preprocess_data  # noqa: F401
+
+
+base_dir = os.path.dirname(__file__)
+version_file = os.path.join(base_dir, "VERSION")
+
+with open(version_file) as f:
+    __version__ = f.read().strip()
+
+
+def main():
+    parser = argparse.ArgumentParser(prog="syngen")
+    parser.add_argument(
+        "--version",
+        action="version",
+        version="%(prog)s " + __version__
+    )
+    args = parser.parse_args()
+    return args
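
This makes the version a single source of truth: it is read from the packaged VERSION file at import time and exposed as syngen.__version__, which src/setup.py above now imports. A small usage sketch; the printed value reflects the VERSION bump in this commit:

import syngen

# __version__ is read from src/syngen/VERSION when the package is imported
print(syngen.__version__)   # "0.7.20" as of this commit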

src/syngen/streamlit_app/__init__.py

+1-1
@@ -1 +1 @@
-from syngen.streamlit_app.start import start
+from syngen.streamlit_app.start import start  # noqa: F401

src/syngen/streamlit_app/css/font_style.css → src/syngen/streamlit_app/css/style.css

+45
Some generated files are not rendered by default.
src/syngen/streamlit_app/handlers/__init__.py

+1

@@ -0,0 +1 @@
+from syngen.streamlit_app.handlers.handlers import StreamlitHandler  # noqa: F401

src/syngen/streamlit_app/handlers/handlers.py

+159

@@ -0,0 +1,159 @@
+import os
+from datetime import datetime
+import traceback
+from queue import Queue
+
+from loguru import logger
+from slugify import slugify
+import streamlit as st
+
+from syngen.ml.worker import Worker
+from syngen.ml.utils import fetch_log_message, ProgressBarHandler
+
+UPLOAD_DIRECTORY = "uploaded_files"
+TIMESTAMP = slugify(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+
+class StreamlitHandler:
+    """
+    A class for handling the Streamlit app
+    """
+
+    def __init__(self, uploaded_file, epochs: int, size_limit: int, print_report: bool):
+        self.log_queue = Queue()
+        self.progress_handler = ProgressBarHandler()
+        self.log_error_queue = Queue()
+        self.epochs = epochs
+        self.size_limit = size_limit
+        self.print_report = print_report
+        self.file_name = uploaded_file.name
+        self.table_name = os.path.splitext(self.file_name)[0]
+        self.file_path = os.path.join(UPLOAD_DIRECTORY, self.file_name)
+        self.sl_table_name = slugify(self.table_name)
+        self.path_to_generated_data = (f"model_artifacts/tmp_store/{self.sl_table_name}/"
+                                       f"merged_infer_{self.sl_table_name}.csv")
+        self.path_to_report = (f"model_artifacts/tmp_store/{self.sl_table_name}/"
+                               f"draws/accuracy_report.html")
+
+    def set_logger(self):
+        """
+        Set a logger to see logs, and collect log messages
+        with the log level - 'INFO' in a log file and stdout
+        """
+        logger.add(self.file_sink, level="INFO")
+        logger.add(self.log_sink, level="INFO")
+
+    def log_sink(self, message):
+        """
+        Put log messages to a log queue
+        """
+        log_message = fetch_log_message(message)
+        self.log_queue.put(log_message)
+
+    def file_sink(self, message):
+        """
+        Write log messages to a log file
+        """
+        path_to_logs = f"model_artifacts/tmp_store/{self.sl_table_name}_{TIMESTAMP}.log"
+        os.environ["SUCCESS_LOG_FILE"] = path_to_logs
+        os.makedirs(os.path.dirname(path_to_logs), exist_ok=True)
+        with open(path_to_logs, "a") as log_file:
+            log_message = fetch_log_message(message)
+            log_file.write(log_message + "\n")
+
+    def train_model(self):
+        """
+        Launch a model training
+        """
+        try:
+            self.set_logger()
+            logger.info("Starting model training...")
+            settings = {
+                "source": self.file_path,
+                "epochs": self.epochs,
+                "row_limit": 10000,
+                "drop_null": False,
+                "batch_size": 32,
+                "print_report": False
+            }
+            worker = Worker(
+                table_name=self.table_name,
+                settings=settings,
+                metadata_path=None,
+                log_level="INFO",
+                type_of_process="train"
+            )
+            ProgressBarHandler().set_progress(0.01)
+            worker.launch_train()
+            logger.info("Model training completed")
+        except Exception:
+            logger.error(f"Error during train: {traceback.format_exc()}")
+            self.log_error_queue.put(f"Error during train: {traceback.format_exc()}")
+
+    def infer_model(self):
+        """
+        Launch a data generation
+        """
+        try:
+            logger.info("Starting data generation...")
+            settings = {
+                "size": self.size_limit,
+                "batch_size": 32,
+                "run_parallel": False,
+                "random_seed": None,
+                "print_report": self.print_report,
+                "get_infer_metrics": False
+            }
+            worker = Worker(
+                table_name=self.table_name,
+                settings=settings,
+                metadata_path=None,
+                log_level="INFO",
+                type_of_process="infer"
+            )
+            worker.launch_infer()
+            logger.info("Data generation completed")
+        except Exception:
+            logger.error(f"Error during infer: {traceback.format_exc()}")
+            self.log_error_queue.put(f"Error during infer: {traceback.format_exc()}")
+
+    def train_and_infer(self):
+        """
+        Launch a model training and data generation
+        """
+        self.train_model()
+        self.infer_model()
+
+    @staticmethod
+    def generate_button(label, path_to_file, download_name):
+        """
+        Generate a download button
+        """
+        if os.path.exists(path_to_file):
+            with open(path_to_file, "rb") as f:
+                st.download_button(
+                    label,
+                    f,
+                    file_name=download_name,
+                )
+
+    def generate_buttons(self):
+        """
+        Generate download buttons for downloading artifacts
+        """
+        self.generate_button(
+            "Download generated data",
+            self.path_to_generated_data,
+            f"generated_data_{self.sl_table_name}.csv"
+        )
+        self.generate_button(
+            "Download logs",
+            os.getenv("SUCCESS_LOG_FILE", ""),
+            f"logs_{self.sl_table_name}.log"
+        )
+        if self.print_report:
+            self.generate_button(
+                "Download report",
+                self.path_to_report,
+                f"accuracy_report_{self.sl_table_name}.html"
+            )
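
StreamlitHandler wraps the upload, training, generation, and download-button steps; the page code in 'streamlit_app/run.py' that drives it is not shown in this excerpt. The sketch below is a hypothetical caller only: the widgets, labels, and the step that saves the upload into 'uploaded_files' are assumptions, while the StreamlitHandler API itself comes from the hunk above.

# Hypothetical driver page (not part of this diff): collect inputs, save the
# uploaded CSV into the upload directory, then let StreamlitHandler do the rest.
import os
import streamlit as st
from syngen.streamlit_app.handlers import StreamlitHandler

uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
epochs = st.number_input("Epochs", min_value=1, value=1)
size_limit = st.number_input("Rows to generate", min_value=1, value=1000)
print_report = st.checkbox("Generate an accuracy report")

if uploaded_file is not None and st.button("Train and generate"):
    os.makedirs("uploaded_files", exist_ok=True)
    with open(os.path.join("uploaded_files", uploaded_file.name), "wb") as f:
        f.write(uploaded_file.getbuffer())
    handler = StreamlitHandler(uploaded_file, epochs, size_limit, print_report)
    handler.train_and_infer()
    handler.generate_buttons()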
