Skip to content

Commit 4f9e723

Browse files
Anna050689, serhio-k, denischernow
authored
implement progress bar (#362)
* Add flow for streamlit * Fix github action * Remove Dockerfiles and update setup.py and start.py * Add streamlit to Dockerfile * Update dependencies in setup.py * Update Dockerfile to install streamlit and altair<5 * Fix pip install command in Dockerfile * Update Dockerfile to install only streamlit package * Update mlflow version to 2.9.2 * Update streamlit command in README * Add streamlit app with file upload and model training * Update file uploader to only accept a single CSV file * Refactor file uploader and training process * Add time delay in training loop * Refactor file uploader and add data generation functionality * not completed changes for streamlit integration * refactor the code in order to represent in the output and save logs in appropriate format * implement progress bar in 'syngen/streamlit_app.py' * refactor 'streamlit_app.py' * minor changes * add checkbox 'Create the accuracy report' * refactor 'streamlit_app.py' * Update UI styles * update Dockerfile, 'setup.py' * update Dockerfile, 'requirements.txt', 'setup.cfg' * update 'README.md', revert changes in 'requirements.txt', 'setup.cfg' * Update fonts * refactor 'streamlit_app.py' * delete the folder with test data * update 'streamlit_app.py' * refactor the code in 'streamlit_app.py' * fix the method 'show_data' of the class StreamlitHandler * temporary changes in 'streamlit_app.py' * refactor 'streamlit_app.py' * add the handling of exceptions in 'streamlit_app.py', add the new page 'Download artifacts' * update the list of icons used for the option menu in streamlit sidebar * remove the page 'Download artifacts' * put the buttons for downloading artifacts in the separate container * refactor the class StreamlitHandler * fix the process of collecting log messages to log file * refactor the method 'run_separate' of the class VaeInferHandler, add the expander widget in 'streamlit_app.py', update unit tests * update the method 'file_sink' of the class Streamlit, minor refactor of the 
progress bar * add the logic related to disabling the button 'Generate data' while running the process of training and inference * refactor 'streamlit_app.py', update 'VERSION' * update 'README.md' * update 'Dockerfile', 'README.md', 'setup.py', refactor the code related to fixing the building process of docker images and the library * refactor the process of building the docker image and the syngen library * fix the process of building the syngen library * update 'VERSION' * update 'README.md' * add the progress bar to the Basic page * refactor the code related to the previous changes * refactor the code related to the implementation of the progress bar * refactor the code related to the previous changes * refactor the code related to the previous changes * refactor the code related to the previous changes * refactor the process of logging * fix issues raised by 'flake8' * add more log messages for the progress bar, refactor the code related to these changes * refactor the method 'update_progress_bar' of the class BaseTest * refactor 'start.py' * update 'VERSION' * update the log message in the method 'generate_report' of the class Report * refactor the method 'handle' of the class VaeInferHandler, remove the spinner in the Basic UI page, remove the duplication of the success log message in the Basic UI page * update 'Dockerfile', 'setup.cfg' * update 'VERSION' --------- Co-authored-by: Sergio <[email protected]> Co-authored-by: Sergei Kudrenko <[email protected]> Co-authored-by: denischernow <[email protected]>
1 parent 37e13ae commit 4f9e723

File tree

18 files changed

+682
-404
lines changed

18 files changed

+682
-404
lines changed

Dockerfile

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ RUN apt-get update && \
1717

1818
COPY src/ .
1919
COPY src/syngen/streamlit_app/.streamlit syngen/.streamlit
20+
COPY src/syngen/streamlit_app/.streamlit/config.toml /root/.streamlit/config.toml
2021
ENV PYTHONPATH "${PYTHONPATH}:/src/syngen"
21-
ENTRYPOINT ["python3", "-m", "start"]
22+
ENTRYPOINT ["python3", "-m", "start"]

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ install_requires =
6666
ui =
6767
streamlit
6868
streamlit_option_menu
69-
altair<5
69+
altair>5
7070

7171

7272
[options.packages.find]

src/start.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
def parse_args():
77
parser = argparse.ArgumentParser(
8-
description="Run training, inference tasks, or a Streamlit web UI.", add_help=False
8+
description="Run training, inference tasks, or a Streamlit web UI.",
9+
add_help=False,
910
)
1011
parser.add_argument(
1112
"--task", choices=["train", "infer"], help="Task to run: 'train' or 'infer'."
@@ -28,15 +29,18 @@ def main():
2829
# Check if the Streamlit web UI should be launched
2930
if known_args.webui:
3031
# Adjust the path to your Streamlit application script if necessary
31-
command = ["streamlit", "run", "syngen/streamlit_app.py"] + remaining_argv
32+
command = ["streamlit", "run", "syngen/streamlit_app/run.py"] + remaining_argv
3233
elif known_args.task == "train":
3334
# Construct the command to run the training script
3435
command = ["python", "syngen/train.py"] + remaining_argv
3536
elif known_args.task == "infer":
3637
# Construct the command to run the inference script
3738
command = ["python", "syngen/infer.py"] + remaining_argv
3839
else:
39-
print("Unknown command. Use --task=train, --task=infer, or --webui.", file=sys.stderr)
40+
print(
41+
"Unknown command. Use --task=train, --task=infer, or --webui.",
42+
file=sys.stderr,
43+
)
4044
sys.exit(1)
4145

4246
# Run the command with any additional arguments

src/syngen/VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.7.2
1+
0.7.3

src/syngen/ml/handlers/handlers.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919

2020
from syngen.ml.vae import * # noqa: F403
2121
from syngen.ml.data_loaders import DataLoader
22-
from syngen.ml.utils import fetch_dataset, check_if_features_assigned, generate_uuid
22+
from syngen.ml.utils import (
23+
fetch_dataset,
24+
check_if_features_assigned,
25+
generate_uuid,
26+
ProgressBarHandler
27+
)
2328
from syngen.ml.context import get_context
24-
from syngen.ml.config import TrainConfig
2529

2630

2731
class AbstractHandler(ABC):
@@ -168,8 +172,9 @@ def __fit_model(self, data: pd.DataFrame):
168172
self.model.batch_size = min(self.batch_size, len(data))
169173

170174
logger.debug(
171-
f"Train model with parameters: epochs={self.epochs}, row_subset={self.row_subset}, "
172-
f"print_report={self.print_report}, drop_null={self.drop_null}, batch_size={self.batch_size}"
175+
f"Train model with parameters: epochs={self.epochs}, "
176+
f"row_subset={self.row_subset}, print_report={self.print_report}, "
177+
f"drop_null={self.drop_null}, batch_size={self.batch_size}"
173178
)
174179

175180
self.model.fit_on_df(
@@ -181,7 +186,9 @@ def __fit_model(self, data: pd.DataFrame):
181186
return
182187

183188
self.model.save_state(self.paths["state_path"])
184-
logger.info("Finished VAE training")
189+
log_message = "Finished VAE training"
190+
logger.info(log_message)
191+
ProgressBarHandler().set_progress(message=log_message)
185192

186193
def __prepare_dir(self):
187194
os.makedirs(self.paths["fk_kde_path"], exist_ok=True)
@@ -434,7 +441,19 @@ def handle(self, **kwargs):
434441
)
435442
logger.info(f"Total of {batch_num} batch(es)")
436443
batches = self.split_by_batches(self.size, batch_num)
437-
prepared_batches = [self.run(batch, self.run_parallel) for batch in batches]
444+
delta = ProgressBarHandler().delta / batch_num
445+
prepared_batches = []
446+
for i, batch in enumerate(batches):
447+
log_message = (f"Data synthesis for the table - '{self.table_name}'. "
448+
f"Generating the batch {i + 1} of {batch_num}")
449+
ProgressBarHandler().set_progress(
450+
progress=ProgressBarHandler().progress + delta,
451+
delta=delta,
452+
message=log_message,
453+
)
454+
logger.info(log_message)
455+
prepared_batch = self.run(batch, self.run_parallel)
456+
prepared_batches.append(prepared_batch)
438457
prepared_data = (
439458
self._concat_slices_with_unique_pk(prepared_batches)
440459
if len(prepared_batches) > 0

src/syngen/ml/metrics/accuracy_test/accuracy_test.py

+106-28
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Utility,
1818
)
1919
from syngen.ml.metrics.utils import transform_to_base64
20-
from syngen.ml.utils import fetch_training_config
20+
from syngen.ml.utils import fetch_training_config, ProgressBarHandler
2121
from syngen.ml.mlflow_tracker import MlflowTracker
2222

2323

@@ -75,13 +75,21 @@ def _log_report_to_mlflow(self, path):
7575
)
7676
pass
7777

78+
@staticmethod
79+
def update_progress_bar(message, delta=0):
80+
ProgressBarHandler().set_progress(
81+
progress=ProgressBarHandler().progress + delta, delta=None, message=message
82+
)
83+
7884
def _get_cleaned_configs(self):
7985
"""
8086
Get cleaned configs for the report
8187
"""
8288
train_config = {
8389
k: v
84-
for k, v in fetch_training_config(self.paths["train_config_pickle_path"]).to_dict().items()
90+
for k, v in fetch_training_config(self.paths["train_config_pickle_path"])
91+
.to_dict()
92+
.items()
8593
if k != "print_report"
8694
}
8795
infer_config = {
@@ -103,90 +111,160 @@ def __init__(
103111
):
104112
super().__init__(original, synthetic, paths, table_name, infer_config)
105113
self.draws_path = f"{self.paths['draws_path']}/accuracy"
106-
self.univariate = UnivariateMetric(self.original, self.synthetic, self.plot_exists, self.draws_path)
107-
self.bivariate = BivariateMetric(self.original, self.synthetic, self.plot_exists, self.draws_path)
108-
self.correlations = Correlations(self.original, self.synthetic, self.plot_exists, self.draws_path)
109-
self.clustering = Clustering(self.original, self.synthetic, self.plot_exists, self.draws_path)
110-
self.utility = Utility(self.original, self.synthetic, self.plot_exists, self.draws_path)
111-
self.acc = JensenShannonDistance(self.original, self.synthetic, self.plot_exists, self.draws_path)
114+
self.univariate = UnivariateMetric(
115+
self.original, self.synthetic, self.plot_exists, self.draws_path
116+
)
117+
self.bivariate = BivariateMetric(
118+
self.original, self.synthetic, self.plot_exists, self.draws_path
119+
)
120+
self.correlations = Correlations(
121+
self.original, self.synthetic, self.plot_exists, self.draws_path
122+
)
123+
self.clustering = Clustering(
124+
self.original, self.synthetic, self.plot_exists, self.draws_path
125+
)
126+
self.utility = Utility(
127+
self.original, self.synthetic, self.plot_exists, self.draws_path
128+
)
129+
self.acc = JensenShannonDistance(
130+
self.original, self.synthetic, self.plot_exists, self.draws_path
131+
)
112132
self._prepare_dir()
113133

114134
def _fetch_metrics(self, **kwargs):
115135
"""
116136
Fetch the main metrics
117137
"""
138+
delta = ProgressBarHandler().delta / 6
139+
140+
self.update_progress_bar("Generation of the accuracy heatmap...")
118141
self.acc.calculate_all(kwargs["categ_columns"])
119142
acc_median = "%.4f" % self.acc.calculate_heatmap_median(self.acc.heatmap)
143+
logger.info(f"Median accuracy is {acc_median}")
144+
self.update_progress_bar("The accuracy heatmap has been generated", delta)
145+
120146
uni_images = dict()
121147
bi_images = dict()
148+
122149
if self.plot_exists:
150+
self.update_progress_bar("Generation of the univariate distributions...")
123151
uni_images = self.univariate.calculate_all(
124152
kwargs["cont_columns"], kwargs["categ_columns"], kwargs["date_columns"]
125153
)
154+
self.update_progress_bar(
155+
"The univariate distributions have been generated", delta
156+
)
157+
158+
self.update_progress_bar("Generation of the bivariate distributions...")
126159
bi_images = self.bivariate.calculate_all(
127160
kwargs["cont_columns"], kwargs["categ_columns"], kwargs["date_columns"]
128161
)
129-
corr_result = self.correlations.calculate_all(kwargs["categ_columns"], kwargs["cont_columns"])
162+
self.update_progress_bar(
163+
"The bivariate distributions have been generated", delta
164+
)
165+
166+
self.update_progress_bar("Generation of the correlations heatmap...")
167+
corr_result = self.correlations.calculate_all(
168+
kwargs["categ_columns"], kwargs["cont_columns"]
169+
)
130170
corr_result = int(corr_result) if corr_result == 0 else abs(corr_result)
171+
logger.info(f"Median of differences of correlations is {round(corr_result, 4)}")
172+
self.update_progress_bar("The correlations heatmap has been generated", delta)
173+
174+
self.update_progress_bar("Generation of the clustering metric...")
131175
clustering_result = "%.4f" % self.clustering.calculate_all(
132176
kwargs["categ_columns"], kwargs["cont_columns"]
133177
)
134-
utility_result = self.utility.calculate_all(kwargs["categ_columns"], kwargs["cont_columns"])
135-
136-
logger.info(f"Median accuracy is {acc_median}")
137-
logger.info(f"Median of differences of correlations is {round(corr_result, 4)}")
138178
logger.info(f"Median clusters homogeneity is {clustering_result}")
179+
self.update_progress_bar("The clustering metric has been calculated", delta)
139180

140-
return acc_median, corr_result, clustering_result, utility_result, uni_images, bi_images
181+
self.update_progress_bar("Generation of the utility metric...")
182+
utility_result = self.utility.calculate_all(
183+
kwargs["categ_columns"], kwargs["cont_columns"]
184+
)
185+
logger.info(f"Median clusters homogeneity is {clustering_result}")
186+
self.update_progress_bar("The utility metric has been calculated", delta)
141187

142-
def _generate_report(
143-
self,
188+
return (
144189
acc_median,
145190
corr_result,
146191
clustering_result,
147192
utility_result,
148193
uni_images,
149-
bi_images
194+
bi_images,
195+
)
196+
197+
def _generate_report(
198+
self,
199+
acc_median,
200+
corr_result,
201+
clustering_result,
202+
utility_result,
203+
uni_images,
204+
bi_images,
150205
):
151206
"""
152207
Generate the report
153208
"""
154-
with open(f"{os.path.dirname(os.path.realpath(__file__))}/accuracy_report.html") as file_:
209+
with open(
210+
f"{os.path.dirname(os.path.realpath(__file__))}/accuracy_report.html"
211+
) as file_:
155212
template = jinja2.Template(file_.read())
156213

157214
draws_acc_path = f"{self.paths['draws_path']}/accuracy"
158-
uni_images = {title: transform_to_base64(path) for title, path in uni_images.items()}
159-
bi_images = {title: transform_to_base64(path) for title, path in bi_images.items()}
215+
uni_images = {
216+
title: transform_to_base64(path) for title, path in uni_images.items()
217+
}
218+
bi_images = {
219+
title: transform_to_base64(path) for title, path in bi_images.items()
220+
}
160221

161222
train_config, infer_config = self._get_cleaned_configs()
162223

163224
html = template.render(
164225
accuracy_value=acc_median,
165-
accuracy_heatmap=transform_to_base64(f"{draws_acc_path}/accuracy_heatmap.svg"),
226+
accuracy_heatmap=transform_to_base64(
227+
f"{draws_acc_path}/accuracy_heatmap.svg"
228+
),
166229
uni_imgs=uni_images,
167-
correlations_heatmap=transform_to_base64(f"{draws_acc_path}/correlations_heatmap.svg"),
230+
correlations_heatmap=transform_to_base64(
231+
f"{draws_acc_path}/correlations_heatmap.svg"
232+
),
168233
correlation_median=corr_result,
169-
clusters_barplot=transform_to_base64(f"{draws_acc_path}/clusters_barplot.svg"),
234+
clusters_barplot=transform_to_base64(
235+
f"{draws_acc_path}/clusters_barplot.svg"
236+
),
170237
clustering_value=clustering_result,
171238
bi_imgs=bi_images,
172-
utility_barplot=transform_to_base64(f"{draws_acc_path}/utility_barplot.svg"),
239+
utility_barplot=transform_to_base64(
240+
f"{draws_acc_path}/utility_barplot.svg"
241+
),
173242
utility_table=utility_result.to_html(),
174243
is_data_available=False if utility_result.empty else True,
175244
table_name=self.table_name,
176245
training_config=train_config,
177246
inference_config=infer_config,
178247
time=datetime.now().strftime("%H:%M:%S %d/%m/%Y"),
179-
round=round
248+
round=round,
180249
)
181250

182-
with open(f"{self.paths['draws_path']}/accuracy_report.html", "w", encoding="utf-8") as f:
251+
with open(
252+
f"{self.paths['draws_path']}/accuracy_report.html", "w", encoding="utf-8"
253+
) as f:
183254
f.write(html)
184255
self._log_report_to_mlflow(f"{self.paths['draws_path']}/accuracy_report.html")
185256
self._remove_artifacts()
186257

187258
def report(self, *args, **kwargs):
188259
metrics = self._fetch_metrics(**kwargs)
189-
acc_median, corr_result, clustering_result, utility_result, uni_images, bi_images = metrics
260+
(
261+
acc_median,
262+
corr_result,
263+
clustering_result,
264+
utility_result,
265+
uni_images,
266+
bi_images,
267+
) = metrics
190268
MlflowTracker().log_metrics(
191269
{
192270
"Utility_avg": utility_result["Synth to orig ratio"].mean(),
@@ -203,5 +281,5 @@ def report(self, *args, **kwargs):
203281
clustering_result,
204282
utility_result,
205283
uni_images,
206-
bi_images
284+
bi_images,
207285
)

0 commit comments

Comments
 (0)