Skip to content

Commit

Permalink
refactor the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Nov 25, 2024
1 parent f90ad12 commit 4b5315c
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.52rc31
0.9.52rc32
38 changes: 34 additions & 4 deletions src/syngen/ml/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,11 @@ def _train_table(self, table, metadata, delta):
batch_size=train_settings["batch_size"],
loader=self.loader
)
self._write_success_message(slugify(table))
self._write_success_file(slugify(table))
self._save_metadata_file()
ProgressBarHandler().set_progress(
delta=delta,
message=f"Training of the table - {table} was completed"
message=f"Training of the table - '{table}' was completed"
)

def __train_tables(
Expand Down Expand Up @@ -361,6 +361,35 @@ def _find_parent_table(self, table):
), None
)

import os
from slugify import slugify

def _check_completion_of_training(self, table_name: str):
"""
Check if the training process of a specific table has been completed.
Args:
table_name (str): The name of the table to check.
Raises:
FileNotFoundError: If the success file does not exist.
ValueError: If the content of the success file does not indicate success.
"""
path_to_success_file = f"model_artifacts/resources/{slugify(table_name)}/message.success"
error_message = (
f"The training of the table - '{table_name}' hasn't been completed. "
"Please, retrain the table."
)

if not os.path.exists(path_to_success_file):
raise FileNotFoundError(error_message)

with open(path_to_success_file, 'r') as file:
content = file.read().strip()

if content != "SUCCESS":
raise ValueError(error_message)

def _infer_table(self, table, metadata, type_of_process, delta, is_nested=False):
"""
Infer process for a single table
Expand Down Expand Up @@ -395,7 +424,7 @@ def _infer_table(self, table, metadata, type_of_process, delta, is_nested=False)
)
ProgressBarHandler().set_progress(
delta=delta,
message=f"Infer process of the table - {table} was completed"
message=f"Infer process of the table - '{table}' was completed"
)
MlflowTracker().end_run()

Expand All @@ -414,6 +443,7 @@ def __infer_tables(

non_surrogate_tables = [table for table in tables if table not in self.divided]
for table in non_surrogate_tables:
self._check_completion_of_training(slugify(table))
self._infer_table(
table=table,
metadata=config_of_tables,
Expand Down Expand Up @@ -455,7 +485,7 @@ def _generate_reports(self):
Report().clear_report()

@staticmethod
def _write_success_message(table_name: str):
def _write_success_file(table_name: str):
"""
Write success message to the '.success' file
"""
Expand Down
14 changes: 14 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,20 @@ def test_metadata_storage():
shutil.rmtree("model_artifacts")


@pytest.fixture
def test_success_file():
path_to_test_dir = "model_artifacts/resources/test-table"
os.makedirs(path_to_test_dir, exist_ok=True)
success_file_path = f"{path_to_test_dir}/message.success"
with open(success_file_path, "w") as f:
f.write("PROGRESS")

yield success_file_path
if os.path.exists(success_file_path):
shutil.rmtree("model_artifacts")



@pytest.fixture
def test_metadata_file():
return {
Expand Down
165 changes: 165 additions & 0 deletions src/tests/unit/test_worker/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import patch, MagicMock

import pytest

from syngen.ml.worker import Worker
from syngen.ml.config import Validator

Expand Down Expand Up @@ -1823,3 +1825,166 @@ def test_launch_train_with_metadata_without_train_settings(
True
)
rp_logger.info(SUCCESSFUL_MESSAGE)



@patch.object(Worker, "_generate_reports")
@patch.object(Worker, "_check_completion_of_training", return_value=None)
@patch.object(Worker, "_infer_table")
@patch.object(Worker, "_collect_metrics_in_infer")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
def test_launch_infer_pretrained_table(
mock_check_existence_of_destination,
mock_validate_metadata,
mock_collect_metrics_in_infer,
mock_infer_table,
mock_check_completion_of_training,
mock_generate_reports,
rp_logger,
):
"""
Test that the inference process has been launched
if the training process of the table has been finished
"""
rp_logger.info(
"Test that the inference process has been launched "
"if the training process of the table has been finished"
)
worker = Worker(
table_name="test_table",
metadata_path=None,
settings={
"size": 300,
"run_parallel": True,
"random_seed": 3,
"reports": ["accuracy"],
"batch_size": 300,
},
log_level="INFO",
type_of_process="infer",
loader=None
)
worker.launch_infer()
mock_check_existence_of_destination.assert_called_once()
mock_validate_metadata.assert_called_once_with("test_table")
mock_infer_table.assert_called_once_with(
table="test_table",
metadata={
"test_table": {
"train_settings": {
"source": None
},
'infer_settings': {
"size": 300,
"run_parallel": True,
"random_seed": 3,
"reports": ["accuracy"],
"batch_size": 300
},
"keys": {},
"format": {}
}
},
type_of_process="infer",
delta=0.25
)
mock_collect_metrics_in_infer.assert_called_once_with(["test_table"])
mock_generate_reports.assert_called_once()


@patch.object(Worker, "_infer_table")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
def test_launch_infer_not_pretrained_table(
mock_check_existence_of_destination,
mock_validate_metadata,
mock_collect_metrics_in_infer,
mock_infer_table,
caplog,
rp_logger,
):
"""
Test that the inference process hasn't been started
in case the training process of the table hasn't been finished,
and the appropriate success file 'message.success' is absent
"""
rp_logger.info(
"Test that the inference process hasn't been started "
"in case the training process of the table hasn't been finished, "
"and the appropriate success file 'message.success' is absent"
)
worker = Worker(
table_name="test_table",
metadata_path=None,
settings={
"size": 300,
"run_parallel": True,
"random_seed": 3,
"reports": ["accuracy"],
"batch_size": 300,
},
log_level="INFO",
type_of_process="infer",
loader=None
)
with pytest.raises(FileNotFoundError):
with caplog.at_level("ERROR"):
worker.launch_infer()
assert (
"The training of the table - 'test-table' hasn't been completed. "
"Please, retrain the table."
in caplog.text
)
mock_infer_table.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_infer_table")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
def test_launch_infer_not_pretrained_table(
mock_check_existence_of_destination,
mock_validate_metadata,
mock_infer_table,
test_success_file,
caplog,
rp_logger,
):
"""
Test that the inference process hasn't been started
in case the training process of the table hasn't been finished,
and the appropriate success file 'message.success' is present,
but the content of the file doesn't correspond to finished training process
"""
rp_logger.info(
"Test that the inference process hasn't been started "
"in case the training process of the table hasn't been finished, "
"and the appropriate success file 'message.success' is present, "
"but the content of the file doesn't correspond to finished training process"
)
worker = Worker(
table_name="test_table",
metadata_path=None,
settings={
"size": 300,
"run_parallel": True,
"random_seed": 3,
"reports": ["accuracy"],
"batch_size": 300,
},
log_level="INFO",
type_of_process="infer",
loader=None
)
with pytest.raises(ValueError):
with caplog.at_level("ERROR"):
worker.launch_infer()
assert (
"The training of the table - 'test-table' hasn't been completed. "
"Please, retrain the table."
in caplog.text
)
mock_infer_table.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)

0 comments on commit 4b5315c

Please sign in to comment.