diff --git a/README.md b/README.md
index 99e6edb5..8202a718 100644
--- a/README.md
+++ b/README.md
@@ -28,10 +28,10 @@ Otherwise, if you want to install the UI version with streamlit, run:
pip install syngen[ui]
```
-*Note*: see details of the UI usage in the [corresponding section](#ui-web-interface)
+*Note:* see details of the UI usage in the [corresponding section](#ui-web-interface)
-The training and inference processes are separated with two cli entry points. The training one receives paths to the original table, metadata json file or table name and used hyperparameters.
+The training and inference processes are separated into two CLI entry points. The training entry point receives a path to the original table, a path to the metadata file, or a table name, along with the hyperparameters to use.
To start training with default parameters, run:
@@ -74,17 +74,33 @@ train --source PATH_TO_ORIGINAL_CSV \
--epochs INT \
--row_limit INT \
--drop_null BOOL \
- --print_report BOOL \
+ --reports STR \
--batch_size INT
```
+*Note:* To specify multiple values for the *--reports* parameter, provide the flag multiple times (see the sketch after the list of accepted values below).
+For example:
+```bash
+train --source PATH_TO_ORIGINAL_CSV \
+ --table_name TABLE_NAME \
+ --reports accuracy \
+ --reports sample
+```
+The accepted values for the parameter "reports":
+ - "none" (default) - no reports will be generated
+  - "accuracy" - generates an accuracy report to measure the quality of the synthetic data relative to the original dataset. The report is produced after training completes; the synthetic data generated for it has the same size as the original dataset, for a more accurate comparison
+  - "sample" - generates a sample report (if the original data is sampled, the report compares the distributions of the original and the sampled data)
+ - "metrics_only" - outputs the metrics information only to standard output without generation of an accuracy report
+ - "all" - generates both accuracy and sample reports
+Default value is "none".
+
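+Under the hood, the entry point collects every occurrence of *--reports* into a tuple (via click's `multiple=True` option, as the updated `infer.py` in this change shows) and validates the combined set afterwards. A minimal, self-contained sketch of that mechanism (illustrative only, not the actual syngen entry point):
+```python
+import click
+
+
+@click.command()
+@click.option("--reports", multiple=True, default=("none",))
+def train(reports):
+    # click gathers each repeated --reports flag into a tuple
+    click.echo(f"collected reports: {list(reports)}")
+
+
+if __name__ == "__main__":
+    train()  # e.g. python demo.py --reports accuracy --reports sample
+```
+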
To train one or more tables using a metadata file, you can use the following command:
```bash
train --metadata_path PATH_TO_METADATA_YAML
```
-The parameters which you can set up for training process:
+Parameters that you can set for the training process:
- source – required parameter for training a single table; a path to the file that you want to use as a reference
- table_name – required parameter for training a single table; an arbitrary string used to name the directories
@@ -92,7 +108,7 @@ The parameters which you can set up for training process:
- row_limit – a number of rows to train over. A number less than the original table length will randomly subset the specified number of rows
- drop_null – whether to drop rows with at least one missing value
- batch_size – if specified, the training is split into batches. This can save the RAM
-- print_report - whether to generate accuracy and sampling reports. Please note that the sampling report is generated only if the `row_limit` parameter is set.
+- reports – controls the generation of quality reports; report generation might require significant time for big tables (>10000 rows)
- metadata_path – a path to the metadata file containing the metadata
- column_types - might include the section categorical which contains the listed columns defined as categorical by a user
@@ -103,7 +119,7 @@ Requirements for parameters of training process:
* row_limit - data type - integer
* drop_null - data type - boolean, default value - False
* batch_size - data type - integer, must be equal to or more than 1, default value - 32
-* print_report - data type - boolean, default value is False
+* reports - data type - string when passed through the CLI, string or list when passed in the metadata file. Accepted values: "none" (default) - no reports will be generated, "all" - generates both accuracy and sample reports, "accuracy" - generates an accuracy report, "sample" - generates a sample report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. In the metadata file, multiple values can be specified as a list of the available options ("accuracy", "sample", "metrics_only") to generate several types of reports simultaneously, e.g. ["metrics_only", "sample"]
* metadata_path - data type - string
* column_types - data type - dictionary with the key categorical - the list of columns (data type - string)
@@ -117,9 +133,23 @@ infer --size INT \
--run_parallel BOOL \
--batch_size INT \
--random_seed INT \
- --print_report BOOL
+ --reports STR
```
+*Note:* To specify multiple values for the *--reports* parameter, provide the flag multiple times.
+For example:
+```bash
+infer --table_name TABLE_NAME \
+ --reports accuracy \
+ --reports metrics_only
+```
+The accepted values for the parameter "reports":
+ - "none" (default) - no reports will be generated
+ - "accuracy" - generates an accuracy report that compares original and synthetic data patterns to verify the quality of the generated data
+ - "metrics_only" - outputs the metrics information only to standard output without generation of an accuracy report
+ - "all" - generates an accuracy report
+Default value is "none".
+
To generate one or more tables using a metadata file, you can use the following command:
```bash
@@ -133,7 +163,7 @@ The parameters which you can set up for generation process:
- run_parallel – whether to use multiprocessing (feasible for tables > 5000 rows)
- batch_size – if specified, the generation is split into batches. This can save the RAM
- random_seed – if specified, generates a reproducible result
-- print_report – whether to generate accuracy and sampling reports. Please note that the sampling report is generated only if the row_limit parameter is set.
+- reports – controls the generation of quality reports; report generation might require significant time for big generated tables (>10000 rows)
- metadata_path – a path to metadata file
Requirements for parameters of generation process:
@@ -142,13 +172,13 @@ Requirements for parameters of generation process:
* run_parallel - data type - boolean, default value is False
* batch_size - data type - integer, must be equal to or more than 1
* random_seed - data type - integer, must be equal to or more than 0
-* print_report - data type - boolean, default value is False
+* reports - data type - string when passed through the CLI, string or list when passed in the metadata file. Accepted values: "none" (default) - no reports will be generated, "all" - generates an accuracy report, "accuracy" - generates an accuracy report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. In the metadata file, multiple values can be specified as a list of the available options ("accuracy", "metrics_only") to generate several types of reports simultaneously
* metadata_path - data type - string
The metadata can contain any of the arguments above for each table. If so, the duplicated arguments from the CLI
will be ignored.
-Note: If you want to set the logging level, you can use the parameter log_level in the CLI call:
+*Note:* If you want to set the logging level, you can use the parameter log_level in the CLI call:
```bash
train --source STR --table_name STR --log_level STR
@@ -159,7 +189,6 @@ infer --metadata_path STR --log_level STR
where log_level might be one of the following values: TRACE, DEBUG, INFO, WARNING, ERROR, CRITICAL.
-
### Linked tables generation
To generate one or more tables, you might provide metadata in yaml format. By providing information about the relationships
@@ -167,7 +196,7 @@ between tables via metadata, it becomes possible to manage complex relationships
You can also specify additional parameters needed for training and inference in the metadata file and in this case,
they will be ignored in the CLI call.
-Note: By using metadata file, you can also generate tables with absent relationships.
+*Note:* By using a metadata file, you can also generate tables without relationships between them.
In this case, the tables will be generated independently.
The yaml metadata file should match the following template:
@@ -179,15 +208,14 @@ global: # Global settings. Optional paramete
drop_null: False # Drop rows with NULL values. Optional parameter
row_limit: null # Number of rows to train over. A number less than the original table length will randomly subset the specified rows number. Optional parameter
batch_size: 32 # If specified, the training is split into batches. This can save the RAM. Optional parameter
- print_report: False # Turn on or turn off generation of the report. Optional parameter
+ reports: none # Controls the generation of quality reports. Optional parameter. Accepted values: "none" (default) - no reports will be generated, "all" - generates both accuracy and sample reports, "accuracy" - generates an accuracy report, "sample" - generates a sample report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. Multiple values can be specified as a list to generate multiple types of reports simultaneously, e.g. ["metrics_only", "sample"]. Might require significant time for big tables (>10000 rows).
infer_settings: # Settings for infer process. Optional parameter
size: 100 # Size for generated data. Optional parameter
run_parallel: False # Turn on or turn off parallel training process. Optional parameter
- print_report: False # Turn on or turn off generation of the report. Optional parameter
+ reports: none # Controls the generation of quality reports. Optional parameter. Accepted values: "none" (default) - no reports will be generated, "all" - generates an accuracy report, "accuracy" - generates an accuracy report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. Multiple values can be specified as a list to generate multiple types of reports simultaneously. Might require significant time for big generated tables (>10000 rows).
batch_size: null # If specified, the generation is split into batches. This can save the RAM. Optional parameter
random_seed: null # If specified, generates a reproducible result. Optional parameter
- get_infer_metrics: False # Whether to fetch metrics for the inference process. If the parameter 'print_report' is set to True, the 'get_infer_metrics' parameter will be ignored and metrics will be fetched anyway. Optional parameter
CUSTOMER: # Table name. Required parameter
train_settings: # Settings for training process. Required parameter
@@ -196,7 +224,7 @@ CUSTOMER: # Table name. Required parameter
drop_null: False # Drop rows with NULL values. Optional parameter
row_limit: null # Number of rows to train over. A number less than the original table length will randomly subset the specified rows number. Optional parameter
batch_size: 32 # If specified, the training is split into batches. This can save the RAM. Optional parameter
- print_report: False # Turn on or turn off generation of the report. Optional parameter
+ reports: none # Controls the generation of quality reports. Optional parameter. Accepted values: "none" (default) - no reports will be generated, "all" - generates both accuracy and sample reports, "accuracy" - generates an accuracy report, "sample" - generates a sample report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. Multiple values can be specified as a list to generate multiple types of reports simultaneously, e.g. ["metrics_only", "sample"]. Might require significant time for big tables (>10000 rows).
column_types:
categorical: # Force listed columns to have categorical type (use dictionary of values). Optional parameter
- gender
@@ -218,10 +246,10 @@ CUSTOMER: # Table name. Required parameter
destination: "./files/generated_data_customer.csv" # The path where the generated data will be stored. If the information about 'destination' isn't specified, by default the synthetic data will be stored locally in '.csv'. Supported formats include local files in '.csv', '.avro' formats. Optional parameter
size: 100 # Size for generated data. Optional parameter
run_parallel: False # Turn on or turn off parallel training process. Optional parameter
- print_report: False # Turn on or turn off generation of the report. Optional parameter
+ reports: none # Controls the generation of quality reports. Optional parameter. Accepted values: "none" (default) - no reports will be generated, "all" - generates an accuracy report, "accuracy" - generates an accuracy report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. Multiple values can be specified as a list to generate multiple types of reports simultaneously. Might require significant time for big generated tables (>10000 rows).
batch_size: null # If specified, the generation is split into batches. This can save the RAM. Optional parameter
random_seed: null # If specified, generates a reproducible result. Optional parameter
- get_infer_metrics: False # Whether to fetch metrics for the inference process. If the parameter 'print_report' is set to True, the 'get_infer_metrics' parameter will be ignored and metrics will be fetched anyway. Optional parameter
+
keys: # Keys of the table. Optional parameter
PK_CUSTOMER_ID: # Name of a key. Only one PK per table.
type: "PK" # The key type. Supported: PK - primary key, FK - foreign key, TKN - token key
@@ -261,20 +289,19 @@ ORDER: # Table name. Required parameter
drop_null: False # Drop rows with NULL values. Optional parameter
row_limit: null # Number of rows to train over. A number less than the original table length will randomly subset the specified rows number. Optional parameter
batch_size: 32 # If specified, the training is split into batches. This can save the RAM. Optional parameter
- print_report: False # Turn on or turn off generation of the report. Optional parameter
+    reports: none # Controls the generation of quality reports. Optional parameter. Accepted values: "none" (default) - no reports will be generated, "all" - generates both accuracy and sample reports, "accuracy" - generates an accuracy report, "sample" - generates a sample report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. Multiple values can be specified as a list to generate multiple types of reports simultaneously, e.g. ["metrics_only", "sample"]. Might require significant time for big tables (>10000 rows).
column_types:
- categorical: # Force listed columns to have categorical type (use dictionary of values). Optional parameter
- - gender
- - marital_status
+ categorical: # Force listed columns to have categorical type (use dictionary of values). Optional parameter
+ - gender
+ - marital_status
infer_settings: # Settings for infer process. Optional parameter
destination: "./files/generated_data_order.csv" # The path where the generated data will be stored. If the information about 'destination' isn't specified, by default the synthetic data will be stored locally in '.csv'. Supported formats include local files in 'csv', '.avro' formats. Required parameter
size: 100 # Size for generated data. Optional parameter
run_parallel: False # Turn on or turn off parallel training process. Optional parameter
- print_report: False # Turn on or turn off generation of the report. Optional parameter
+ reports: none # Controls the generation of quality reports. Optional parameter. Accepted values: "none" (default) - no reports will be generated, "all" - generates an accuracy report, "accuracy" - generates an accuracy report, "metrics_only" - outputs the metrics information only to standard output without generation of a report. Multiple values can be specified as a list to generate multiple types of reports simultaneously. Might require significant time for big generated tables (>10000 rows).
batch_size: null # If specified, the generation is split into batches. This can save the RAM. Optional parameter
random_seed: null # If specified, generates a reproducible result. Optional parameter
- get_infer_metrics: False # Whether to fetch metrics for the inference process. If the parameter 'print_report' is set to True, the 'get_infer_metrics' parameter will be ignored and metrics will be fetched anyway. Optional parameter
format: # Settings for reading and writing data in 'csv' format. Optional parameter
sep: ',' # Delimiter to use. Optional parameter
quotechar: '"' # The character used to denote the start and end of a quoted item. Optional parameter
@@ -298,11 +325,11 @@ ORDER: # Table name. Required parameter
- customer_id
references:
table: "CUSTOMER"
- columns:
+ columns:
- customer_id
```
-Note:
+*Note:*
- In the section "global" you can specify training and inference settings for all tables. If the same settings are specified for a specific table, they will override the global settings (see the sketch after this list)
- If the information about "destination" isn't specified in "infer_settings", by default the synthetic data will be stored locally in ".csv" format
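+
+A sketch of the override semantics described above, using hypothetical settings (per-table values win over the "global" section):
+```python
+# Hypothetical settings, for illustration only
+global_settings = {"row_limit": 1000, "reports": ["accuracy"]}
+table_settings = {"reports": []}
+
+# Per-table settings override the global ones key by key
+effective = {**global_settings, **table_settings}
+assert effective == {"row_limit": 1000, "reports": []}
+```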
@@ -325,6 +352,58 @@ infer --metadata_path="./examples/example-metadata/housing_metadata.yaml"
If `--metadata_path` is present and the metadata contains the necessary parameters, other CLI parameters will be ignored.
+### Ways to set the value(s) of the "reports" parameter in the metadata file
+
+The accepted values for the parameter "reports" in "train_settings":
+ - "none" (default) - no reports will be generated
+  - "accuracy" - generates an accuracy report to measure the quality of the synthetic data relative to the original dataset. The report is produced after training completes; the synthetic data generated for it has the same size as the original dataset, for a more accurate comparison
+  - "sample" - generates a sample report (if the original data is sampled, the report compares the distributions of the original and the sampled data)
+ - "metrics_only" - outputs the metrics information only to standard output without generation of an accuracy report
+ - "all" - generates both accuracy and sample reports
+Default value is "none".
+
+Examples of how to set the "reports" parameter in "train_settings":
+```yaml
+
+reports: none
+
+reports: all
+
+reports: accuracy
+
+reports: metrics_only
+
+reports: sample
+
+reports:
+ - accuracy
+ - metrics_only
+ - sample
+```
+The accepted values for the parameter "reports" in "infer_settings":
+ - "none" (default) - no reports will be generated
+ - "accuracy" - generates an accuracy report to verify the quality of the generated data
+ - "metrics_only" - outputs the metrics information only to standard output without generation of an accuracy report
+ - "all" - generates an accuracy report
+Default value is "none".
+
+Examples of how to set the "reports" parameter in "infer_settings":
+```yaml
+
+reports: none
+
+reports: all
+
+reports: accuracy
+
+reports: metrics_only
+
+reports:
+ - accuracy
+ - metrics_only
+```
+
+
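+For reference, the string form and the list form reduce to the same internal representation. The sketch below mirrors the normalization performed by `YAMLLoader._normalize_reports` introduced in this release; the helper itself is illustrative, not part of the library API:
+```python
+TRAIN_ALL = ["accuracy", "sample"]  # "metrics_only" is excluded when "all" is expanded
+INFER_ALL = ["accuracy"]
+
+
+def normalize_reports(value, process="train"):
+    # lists are taken as-is; strings map to the equivalent list
+    if isinstance(value, list):
+        return value
+    if value == "none":
+        return []
+    if value == "all":
+        return TRAIN_ALL if process == "train" else INFER_ALL
+    return [value]
+
+
+assert normalize_reports("all") == ["accuracy", "sample"]
+assert normalize_reports("all", process="infer") == ["accuracy"]
+assert normalize_reports(["metrics_only", "sample"]) == ["metrics_only", "sample"]
+```
+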
### Docker images
The train and inference components of syngen are available as a public docker image:
@@ -382,7 +461,7 @@ pip install syngen[ui]
then create a python file and insert the code provided below into it:
```python
from syngen import streamlit_app
streamlit_app.start()
diff --git a/requirements.txt b/requirements.txt
index 97607508..96235017 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-aiohttp>=3.9.0
+aiohttp>=3.10.11
attrs
avro
base32-crockford
@@ -31,6 +31,7 @@ scipy==1.14.*
seaborn==0.13.*
setuptools==74.1.*
tensorflow==2.15.*
+tornado==6.4.*
tqdm==4.66.3
Werkzeug==3.1.2
xlrd
diff --git a/setup.cfg b/setup.cfg
index 8ce9bb53..6af98198 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -25,7 +25,7 @@ packages = find:
include_package_data = True
python_requires = >3.9, <3.12
install_requires =
- aiohttp>=3.9.0
+ aiohttp>=3.10.11
attrs
avro
base32-crockford
@@ -58,6 +58,7 @@ install_requires =
seaborn==0.13.*
setuptools==74.1.*
tensorflow==2.15.*
+ tornado==6.4.*
tqdm==4.66.3
Werkzeug==3.1.2
xlrd
diff --git a/src/syngen/VERSION b/src/syngen/VERSION
index 1ba936fe..78bc1abd 100644
--- a/src/syngen/VERSION
+++ b/src/syngen/VERSION
@@ -1 +1 @@
-0.9.52
+0.10.0
diff --git a/src/syngen/infer.py b/src/syngen/infer.py
index 2f9c8929..c95c6adc 100644
--- a/src/syngen/infer.py
+++ b/src/syngen/infer.py
@@ -1,5 +1,5 @@
import os
-from typing import Optional
+from typing import Optional, List
import traceback
import click
@@ -11,6 +11,14 @@
set_log_path,
check_if_logs_available
)
+from syngen.ml.utils import validate_parameter_reports
+from syngen.ml.validation_schema import ReportTypes
+
+
+validate_reports = validate_parameter_reports(
+ report_types=ReportTypes().infer_report_types,
+ full_list=ReportTypes().full_list_of_infer_report_types
+)
@click.command()
@@ -48,11 +56,18 @@
"use the same int in this command.",
)
@click.option(
- "--print_report",
- default=False,
- type=click.BOOL,
- help="Whether to print quality report. Might require significant time "
- "for big generated tables (>1000 rows). If absent, it's defaulted to False",
+ "--reports",
+ default=("none",),
+ type=click.UNPROCESSED,
+ multiple=True,
+ callback=validate_reports,
+ help="Controls the generation of quality reports. "
+ "Might require significant time for big generated tables (>10000 rows). "
+ "If set to 'accuracy', generates an accuracy report. "
+ "If set to 'metrics_only', outputs the metrics information "
+ "only to standard output without generation of a report. "
+ "If set to 'all', generates an accuracy report. "
+ "If it's absent or set to 'none', no reports are generated.",
)
@click.option(
"--log_level",
@@ -67,7 +82,7 @@ def launch_infer(
table_name: Optional[str],
run_parallel: bool,
batch_size: Optional[int],
- print_report: bool,
+ reports: List[str],
random_seed: Optional[int],
log_level: str,
):
@@ -80,7 +95,7 @@ def launch_infer(
table_name
run_parallel
batch_size
- print_report
+ reports
random_seed
log_level
-------
@@ -111,9 +126,8 @@ def launch_infer(
"size": size,
"run_parallel": run_parallel,
"batch_size": batch_size,
- "print_report": print_report,
- "random_seed": random_seed,
- "get_infer_metrics": False
+ "reports": reports,
+ "random_seed": random_seed
}
worker = Worker(
table_name=table_name,
diff --git a/src/syngen/ml/config/configurations.py b/src/syngen/ml/config/configurations.py
index f6892fa0..110c049f 100644
--- a/src/syngen/ml/config/configurations.py
+++ b/src/syngen/ml/config/configurations.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
-from typing import Optional, Dict, Tuple, Set, List, Callable
+from typing import Optional, Dict, Tuple, Set, List, Callable, Literal
import os
+from copy import deepcopy
import shutil
from datetime import datetime
@@ -23,11 +24,12 @@ class TrainConfig:
drop_null: bool
row_limit: Optional[int]
table_name: Optional[str]
- metadata_path: Optional[str]
- print_report: bool
+ metadata: Dict
+ reports: List[str]
batch_size: int
loader: Optional[Callable[[str], pd.DataFrame]]
data: pd.DataFrame = field(init=False)
+ initial_data_shape: Tuple[int, int] = field(init=False)
paths: Dict = field(init=False)
row_subset: int = field(init=False)
schema: Dict = field(init=False)
@@ -37,7 +39,7 @@ class TrainConfig:
dropped_columns: Set = field(init=False)
def __post_init__(self):
- self.paths = self._get_paths()
+ self._set_paths()
self._remove_existed_artifacts()
self._prepare_dirs()
@@ -59,6 +61,7 @@ def preprocess_data(self):
self._remove_empty_columns()
self._mark_removed_columns()
self._prepare_data()
+ self._check_reports()
def to_dict(self) -> Dict:
"""
@@ -69,7 +72,7 @@ def to_dict(self) -> Dict:
"drop_null": self.drop_null,
"row_subset": self.row_subset,
"batch_size": self.batch_size,
- "print_report": self.print_report
+ "reports": self.reports
}
def _set_batch_size(self):
@@ -78,6 +81,25 @@ def _set_batch_size(self):
"""
self.batch_size = min(self.batch_size, self.row_subset)
+ def _check_sample_report(self):
+ """
+        Drop the sample report when the source data was not actually sampled
+ """
+ if "sample" in self.reports and self.initial_data_shape[0] == self.row_subset:
+ logger.warning(
+ "The generation of the sample report is unnecessary and won't be produced "
+ "as the source data and sampled data sizes are identical"
+ )
+ reports = deepcopy(self.reports)
+ reports.remove("sample")
+ self.reports = reports
+
+ def _check_reports(self):
+ """
+ Check whether it is necessary to generate a certain report
+ """
+ self._check_sample_report()
+
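
The guard above reduces to a small pure function; a standalone sketch of the same rule (names are illustrative, not the library API):

```python
def drop_redundant_sample(reports: list, initial_rows: int, row_subset: int) -> list:
    """Drop the 'sample' report when no sampling actually happened."""
    if "sample" in reports and initial_rows == row_subset:
        return [r for r in reports if r != "sample"]
    return reports


assert drop_redundant_sample(["accuracy", "sample"], 1000, 1000) == ["accuracy"]
assert drop_redundant_sample(["accuracy", "sample"], 1000, 500) == ["accuracy", "sample"]
```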
def _remove_existed_artifacts(self):
"""
Remove existed artifacts from previous train process
@@ -166,6 +188,7 @@ def _extract_data(self):
Extract data and schema necessary for training process
"""
self.data, self.schema = self._load_source()
+ self.initial_data_shape = self.data.shape
self._check_if_data_is_empty()
def _preprocess_data(self):
@@ -243,7 +266,7 @@ def _prepare_data(self):
self._save_input_data()
@slugify_attribute(table_name="slugify_table_name")
- def _get_paths(self) -> Dict:
+ def _set_paths(self):
"""
Create the paths which used in training process
"""
@@ -251,7 +274,7 @@ def _get_paths(self) -> Dict:
f"losses_{self.table_name}_"
f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
- return {
+ self.paths = {
"model_artifacts_path": "model_artifacts/",
"resources_path": f"model_artifacts/resources/{self.slugify_table_name}/",
"tmp_store_path": f"model_artifacts/tmp_store/{self.slugify_table_name}/",
@@ -285,74 +308,93 @@ class InferConfig:
"""
destination: Optional[str]
+ metadata: Dict
+ metadata_path: Optional[str]
size: Optional[int]
table_name: Optional[str]
run_parallel: bool
batch_size: Optional[int]
- metadata_path: Optional[str]
random_seed: Optional[int]
- print_report: bool
- get_infer_metrics: bool
+ reports: List[str]
both_keys: bool
log_level: str
loader: Optional[Callable[[str], pd.DataFrame]]
+ type_of_process: Literal["train", "infer"]
slugify_table_name: str = field(init=False)
def __post_init__(self):
- self.paths = self._get_paths()
- self._set_up_reporting()
+ self._set_paths()
+ self._remove_artifacts()
+ self._set_infer_parameters()
+
+ def _set_infer_parameters(self):
+ self._check_reports()
self._set_up_size()
self._set_up_batch_size()
+ def _remove_reports(self):
+ path_to_reports = self.paths["reports_path"]
+ if os.path.exists(path_to_reports):
+ shutil.rmtree(path_to_reports)
+ logger.info(
+ f"The reports generated in the previous run of an inference process "
+ f"and located in the path - '{path_to_reports}' were removed"
+ )
+
+ def _remove_generated_data(self):
+ default_path_to_synth_data = self.paths["default_path_to_merged_infer"]
+ if os.path.exists(default_path_to_synth_data):
+ os.remove(default_path_to_synth_data)
+ logger.info(
+ f"The synthetic data generated in the previous run of an inference process and "
+ f"located in the path - '{default_path_to_synth_data}' was removed"
+ )
+
+ def _remove_artifacts(self):
+ """
+ Remove artifacts related to the previous generation process
+ """
+ self._remove_reports()
+ self._remove_generated_data()
+
def to_dict(self) -> Dict:
"""
Return the values of the settings of inference process
- :return:
"""
return {
"size": self.size,
"run_parallel": self.run_parallel,
"batch_size": self.batch_size,
"random_seed": self.random_seed,
- "print_report": self.print_report,
- "get_infer_metrics": self.get_infer_metrics,
+ "reports": self.reports,
}
- def _set_up_reporting(self):
+ def _check_required_artifacts(self):
"""
- Check whether it is possible to generate the report
+        Check whether the required artifacts exist
"""
if (
- (self.print_report or self.get_infer_metrics)
+ self.reports
and (
- not DataLoader(self.paths["input_data_path"]).has_existed_path
- and not self.loader
+                not DataLoader(self.paths["input_data_path"]).has_existed_path
+ or self.loader is not None
)
):
- message = (
- f"It seems that the path to the sample of the original data "
- f"of the table '{self.table_name}' - '{self.paths['input_data_path']}' "
- f"doesn't exist."
+ self.reports = list()
+ log_message = (
+ f"It seems that the path to the sample of the original data for the table "
+ f"'{self.table_name}' at '{self.paths['input_data_path']}' does not exist. "
+ f"As a result, no reports for the table '{self.table_name}' will be generated. "
+ f"The 'reports' parameter for the table '{self.table_name}' "
+ f"has been set to 'none'."
)
- logger.warning(message)
- if self.print_report:
- self.print_report = False
- log_message = (
- "As a result, the accuracy report of the table - "
- f"'{self.table_name}' won't be generated. "
- "The parameter '--print_report' of the table - "
- f"'{self.table_name}' has been set to False"
- )
- logger.warning(log_message)
- if self.get_infer_metrics:
- self.get_infer_metrics = False
- log_message = (
- "As a result, the infer metrics related to the table - "
- f"'{self.table_name}' won't be fetched. "
- "The parameter '--get_infer_metrics' of the table - "
- f"'{self.table_name}' has been set to False"
- )
- logger.warning(log_message)
+ logger.warning(log_message)
+
+ def _check_reports(self):
+ """
+ Check whether it is possible to generate reports
+ """
+ self._check_required_artifacts()
def _set_up_size(self):
"""
@@ -379,17 +421,19 @@ def _set_up_batch_size(self):
)
@slugify_attribute(table_name="slugify_table_name")
- def _get_paths(self) -> Dict:
+ def _set_paths(self):
"""
Create the paths which used in inference process
"""
dynamic_name = self.slugify_table_name[:-3] if self.both_keys else self.slugify_table_name
- return {
+ self.paths = {
"original_data_path":
f"model_artifacts/tmp_store/{dynamic_name}/input_data_{dynamic_name}.pkl",
"reports_path": f"model_artifacts/tmp_store/{dynamic_name}/reports",
"input_data_path":
f"model_artifacts/tmp_store/{dynamic_name}/input_data_{dynamic_name}.pkl",
+ "default_path_to_merged_infer": f"model_artifacts/tmp_store/{dynamic_name}/"
+ f"merged_infer_{dynamic_name}.csv",
"path_to_merged_infer": self.destination
if self.destination is not None
else f"model_artifacts/tmp_store/{dynamic_name}/merged_infer_{dynamic_name}.csv",
diff --git a/src/syngen/ml/config/validation.py b/src/syngen/ml/config/validation.py
index 2fa8e4ec..6ab71c0d 100644
--- a/src/syngen/ml/config/validation.py
+++ b/src/syngen/ml/config/validation.py
@@ -8,7 +8,7 @@
from slugify import slugify
from loguru import logger
from syngen.ml.data_loaders import MetadataLoader, DataLoader
-from syngen.ml.validation_schema import ValidationSchema
+from syngen.ml.validation_schema import ValidationSchema, ReportTypes
@dataclass
@@ -22,6 +22,9 @@ class Validator:
type_of_process: Literal["train", "infer"]
validation_source: bool = True
type_of_fk_keys = ["FK"]
+ infer_report_types: List[str] = field(
+ default_factory=lambda: ReportTypes().infer_report_types
+ )
merged_metadata: Dict = field(default_factory=dict)
mapping: Dict = field(default_factory=dict)
existed_columns_mapping: Dict = field(default_factory=dict)
@@ -64,10 +67,13 @@ def _check_conditions(self, metadata: Dict) -> bool:
"""
Check conditions whether to launch validation or not
"""
- print_report = metadata.get("train_settings", {}).get("print_report", False)
+ reports = metadata.get("train_settings", {}).get("reports", [])
return (
self.type_of_process == "infer"
- or (self.type_of_process == "train" and print_report is True)
+ or (
+ self.type_of_process == "train" and
+ any([item in self.infer_report_types for item in reports])
+ )
)
def _validate_metadata(self, table_name: str):
@@ -87,10 +93,10 @@ def _validate_metadata(self, table_name: str):
parent_table = self.mapping[key]["parent_table"]
if parent_table not in self.metadata:
if self._check_conditions(metadata_of_the_table):
- self._check_existence_of_success_file(parent_table)
+ self._check_completion_of_training(parent_table)
self._check_existence_of_generated_data(parent_table)
elif self.type_of_process == "train":
- self._check_existence_of_success_file(parent_table)
+ self._check_completion_of_training(parent_table)
else:
continue
@@ -113,21 +119,6 @@ def _validate_referential_integrity(self, fk_name: str, fk_config: Dict, parent_
)
self.errors["validate referential integrity"][fk_name] = message
- def _check_existence_of_success_file(self, parent_table: str):
- """
- Check if the success file of the certain parent table exists.
- The success file is created after the successful execution of the training process
- of the certain table.
- """
- if not os.path.exists(
- f"model_artifacts/resources/{slugify(parent_table)}/message.success"
- ):
- message = (
- f"The table - '{parent_table}' hasn't been trained completely. "
- f"Please, retrain this table first"
- )
- self.errors["check existence of the success file"][parent_table] = message
-
def _check_existence_of_generated_data(self, parent_table: str):
"""
Check if the generated data of the certain parent table exists.
@@ -182,6 +173,30 @@ def _check_existence_of_destination(self, table_name: str):
)
self.errors["check existence of the destination"][table_name] = message
+ def _check_completion_of_training(self, table_name: str):
+ """
+ Check if the training process of a specific table has been completed.
+
+ Args:
+ table_name (str): The name of the table to check.
+
+        Records a validation error if the success file does not exist
+        or if its content does not indicate success.
+ """
+ path_to_success_file = f"model_artifacts/resources/{slugify(table_name)}/message.success"
+ error_message = (
+ f"The training of the table - '{table_name}' hasn't been completed. "
+ "Please, retrain the table."
+ )
+
+        content = None
+        if os.path.exists(path_to_success_file):
+            with open(path_to_success_file, "r") as file:
+                content = file.read().strip()
+
+        if content != "SUCCESS":
+            self.errors["check completion of the training process"][table_name] = error_message
+
def _check_merged_metadata(self, parent_table: str):
if parent_table not in self.merged_metadata:
message = (
@@ -297,12 +312,13 @@ def _run(self):
for table_name in self.merged_metadata.keys():
self._gather_existed_columns(table_name)
- for table_name in self.merged_metadata.keys():
+ for table_name in self.metadata.keys():
if self.type_of_process == "train" and self.validation_source:
self._check_existence_of_source(table_name)
self._check_existence_of_key_columns(table_name)
self._check_existence_of_referenced_columns(table_name)
elif self.type_of_process == "infer":
+ self._check_completion_of_training(table_name)
self._check_existence_of_destination(table_name)
for table_name in self.metadata.keys():
diff --git a/src/syngen/ml/data_loaders/data_loaders.py b/src/syngen/ml/data_loaders/data_loaders.py
index 2656adc4..74f7128f 100644
--- a/src/syngen/ml/data_loaders/data_loaders.py
+++ b/src/syngen/ml/data_loaders/data_loaders.py
@@ -1,10 +1,11 @@
import os
from pathlib import Path
from abc import ABC, abstractmethod
-from typing import Optional, Dict, Tuple, List
+from typing import Optional, Dict, Tuple, List, Literal
import pickle as pkl
import csv
import inspect
+from dataclasses import dataclass
import pandas as pd
import pandas.errors
@@ -23,6 +24,7 @@
from syngen.ml.validation_schema import (
ExcelFormatSettingsSchema,
CSVFormatSettingsSchema,
+ ReportTypes
)
DELIMITERS = {"\\t": "\t"}
@@ -402,19 +404,44 @@ def save_data(self, metadata: Dict, **kwargs):
self.metadata_loader.save_data(metadata, **kwargs)
+@dataclass
class YAMLLoader(BaseDataLoader):
"""
Class for loading and saving data in YAML format
"""
+ metadata_sections = ["train_settings", "infer_settings", "format", "keys"]
+ train_reports = ReportTypes().full_list_of_train_report_types
+ infer_reports = ReportTypes().full_list_of_infer_report_types
- _metadata_sections = ["train_settings", "infer_settings", "keys"]
+ def __init__(self, path: str):
+ super().__init__(path)
+
+ def _normalize_reports(self, settings: dict, type_of_process: Literal["train", "infer"]):
+ """
+ Cast the value of the parameter 'reports' to the list
+ """
+ reports = settings.get(f"{type_of_process}_settings", {}).get("reports", [])
+ if isinstance(reports, str):
+        if isinstance(reports, str):
+            if reports == "none":
+                settings[f"{type_of_process}_settings"]["reports"] = []
+            elif reports == "all":
+                settings[f"{type_of_process}_settings"]["reports"] = (
+                    self.train_reports
+                    if type_of_process == "train"
+                    else self.infer_reports
+                )
+            else:
+                settings[f"{type_of_process}_settings"]["reports"] = [reports]
+
+ def _normalize_parameter_reports(self, metadata: dict) -> dict:
+ for table, settings in metadata.items():
+ self._normalize_reports(settings, "train")
+ self._normalize_reports(settings, "infer")
+ return metadata
def _load_data(self, metadata_file) -> Dict:
try:
metadata = yaml.load(metadata_file, Loader=SafeLoader)
- metadata = self.replace_none_values_of_metadata_settings(
- self._metadata_sections, metadata
- )
+ metadata = self.replace_none_values_of_metadata_settings(metadata)
+ metadata = self._normalize_parameter_reports(metadata)
return metadata
except ScannerError as error:
message = (
@@ -429,8 +456,7 @@ def load_data(self) -> Dict:
with open(self.path, "r", encoding="utf-8") as f:
return self._load_data(f)
- @staticmethod
- def replace_none_values_of_metadata_settings(parameters, metadata: dict):
+ def replace_none_values_of_metadata_settings(self, metadata: dict):
"""
Replace None values for parameters in the metadata
"""
@@ -445,7 +471,7 @@ def replace_none_values_of_metadata_settings(parameters, metadata: dict):
for key in metadata.keys():
if key == "global":
continue
- for parameter in parameters:
+ for parameter in self.metadata_sections:
if metadata.get(key).get(parameter) is None:
metadata[key][parameter] = {}
return metadata
diff --git a/src/syngen/ml/handlers/handlers.py b/src/syngen/ml/handlers/handlers.py
index daa4e082..4e6d952f 100644
--- a/src/syngen/ml/handlers/handlers.py
+++ b/src/syngen/ml/handlers/handlers.py
@@ -167,7 +167,7 @@ class VaeTrainHandler(BaseHandler):
drop_null: bool = field(kw_only=True)
batch_size: int = field(kw_only=True)
type_of_process: str = field(kw_only=True)
- print_report: bool = field(kw_only=True)
+ reports: List[str] = field(kw_only=True)
def __fit_model(self, data: pd.DataFrame):
logger.info("Start VAE training")
@@ -187,11 +187,12 @@ def __fit_model(self, data: pd.DataFrame):
process="train",
)
self.model.batch_size = min(self.batch_size, len(data))
-
+ list_of_reports = [f'"{report}"' for report in self.reports]
+ list_of_reports = ', '.join(list_of_reports) if list_of_reports else '"none"'
logger.debug(
f"Train model with parameters: epochs={self.epochs}, "
- f"row_subset={self.row_subset}, print_report={self.print_report}, "
- f"drop_null={self.drop_null}, batch_size={self.batch_size}"
+ f"row_subset={self.row_subset}, drop_null={self.drop_null}, "
+ f"batch_size={self.batch_size}, reports - {list_of_reports}"
)
self.model.fit_on_df(epochs=self.epochs)
@@ -220,8 +221,7 @@ class VaeInferHandler(BaseHandler):
size: int = field(kw_only=True)
batch_size: int = field(kw_only=True)
run_parallel: bool = field(kw_only=True)
- print_report: bool = field(kw_only=True)
- get_infer_metrics: bool = field(kw_only=True)
+ reports: List[str] = field(kw_only=True)
wrapper_name: str = field(kw_only=True)
log_level: str = field(kw_only=True)
type_of_process: str = field(kw_only=True)
@@ -467,11 +467,16 @@ def _restore_empty_columns(self, df: pd.DataFrame) -> pd.DataFrame:
def handle(self, **kwargs):
self._prepare_dir()
- logger.debug(
- f"Infer model with parameters: size={self.size}, run_parallel={self.run_parallel}, "
- f"batch_size={self.batch_size}, random_seed={self.random_seed}, "
- f"print_report={self.print_report}, get_infer_metrics={self.get_infer_metrics}"
+ list_of_reports = [f'"{report}"' for report in self.reports]
+ list_of_reports = ', '.join(list_of_reports) if list_of_reports else '"none"'
+ log_message = (
+ f"Infer model with parameters: size={self.size}, "
+ f"run_parallel={self.run_parallel}, batch_size={self.batch_size}, "
+ f"random_seed={self.random_seed}"
)
+ if self.type_of_process == "infer":
+ log_message += f", reports - {list_of_reports}"
+ logger.debug(log_message)
logger.info(f"Total of {self.batch_num} batch(es)")
batches = self.split_by_batches()
delta = ProgressBarHandler().delta / self.batch_num
diff --git a/src/syngen/ml/metrics/accuracy_test/accuracy_test.py b/src/syngen/ml/metrics/accuracy_test/accuracy_test.py
index 0eec1eb9..14b8dd74 100644
--- a/src/syngen/ml/metrics/accuracy_test/accuracy_test.py
+++ b/src/syngen/ml/metrics/accuracy_test/accuracy_test.py
@@ -20,6 +20,7 @@
from syngen.ml.metrics.utils import transform_to_base64
from syngen.ml.utils import fetch_config, ProgressBarHandler
from syngen.ml.mlflow_tracker import MlflowTracker
+from syngen.ml.validation_schema import ReportTypes
class BaseTest(ABC):
@@ -36,8 +37,13 @@ def __init__(
self.paths = paths
self.table_name = table_name
self.config = config
- self.plot_exists = (self.config.get("print_report", False)
- or self.config.get("privacy_report", False))
+ self.plot_exists = any(
+ [
+ item in ReportTypes().full_list_of_infer_report_types
+ for item
+ in self.config.get("reports", [])
+ ]
+ )
self.reports_path = str()
@abstractmethod
@@ -89,7 +95,7 @@ def _get_cleaned_configs(self):
"""
Get cleaned configs for the report
"""
- filtered_fields = ["print_report", "get_infer_metrics", "privacy_report"]
+ filtered_fields = ["reports"]
train_config = {
k: v
for k, v in fetch_config(self.paths["train_config_pickle_path"])
diff --git a/src/syngen/ml/mlflow_tracker/mlflow_tracker.py b/src/syngen/ml/mlflow_tracker/mlflow_tracker.py
index 2febd349..11551f50 100644
--- a/src/syngen/ml/mlflow_tracker/mlflow_tracker.py
+++ b/src/syngen/ml/mlflow_tracker/mlflow_tracker.py
@@ -252,7 +252,7 @@ def log_metrics(
def search_runs(self, table_name: str, type_of_process: str):
"""
- Get the list of runs related the certain experment
+        Get the list of runs related to the certain experiment
"""
if self.is_active:
run = mlflow.search_runs(
diff --git a/src/syngen/ml/reporters/reporters.py b/src/syngen/ml/reporters/reporters.py
index dc4cfa42..e1022157 100644
--- a/src/syngen/ml/reporters/reporters.py
+++ b/src/syngen/ml/reporters/reporters.py
@@ -103,14 +103,13 @@ def convert_to_timestamp(df, col_name, date_format, na_values):
if d not in na_values else np.NaN for d in df[col_name]
]
- def preprocess_data(self):
+ def preprocess_data(self, original: pd.DataFrame, synthetic: pd.DataFrame):
"""
Preprocess original and synthetic data.
Return original data, synthetic data, float columns, integer columns, categorical columns
without keys columns
"""
types = self.fetch_data_types()
- original, synthetic = self._extract_report_data()
missing_columns = set(original) - set(synthetic)
for col in missing_columns:
synthetic[col] = np.nan
@@ -241,45 +240,66 @@ def _group_reporters(cls):
return grouped_reporters
@classmethod
- def generate_report(cls, type_of_process: str):
+ def generate_report(cls):
"""
Generate all needed reports
"""
grouped_reporters = cls._group_reporters()
+ if grouped_reporters:
+            logger.warning("The generation of the report(s) might be time-consuming")
+
for table_name, reporters in grouped_reporters.items():
- MlflowTracker().start_run(
- run_name=f"{table_name}-REPORT",
- tags={"table_name": table_name, "process": "report"},
- )
+ cls._start_mlflow_run(table_name)
delta = 0.25 / len(reporters)
+
for reporter in reporters:
- message = (f"The calculation of {reporter.__class__.report_type} metrics "
- f"for the table - '{reporter.table_name}' has started")
- ProgressBarHandler().set_progress(delta=delta, message=message)
- reporter.report()
- if reporter.config["print_report"]:
- message = (f"The {reporter.__class__.report_type} report of the table - "
- f"'{reporter.table_name}' has been generated")
- logger.info(
- f"The {reporter.__class__.report_type} report of the table - "
- f"'{reporter.table_name}' has been generated"
- )
- ProgressBarHandler().set_progress(
- progress=ProgressBarHandler().progress + delta,
- delta=delta,
- message=message
- )
-
- if (
- not reporter.config["print_report"]
- and reporter.config.get("get_infer_metrics") is not None
- ):
- logger.info(
- f"The metrics for the table - '{reporter.table_name}' have been evaluated"
- )
+ cls._launch_reporter(reporter, delta)
+
MlflowTracker().end_run()
+ @staticmethod
+ def _start_mlflow_run(table_name: str):
+ MlflowTracker().start_run(
+ run_name=f"{table_name}-REPORT",
+ tags={"table_name": table_name, "process": "report"},
+ )
+
+ @classmethod
+ def _launch_reporter(cls, reporter, delta: float):
+ cls._log_and_update_progress(
+ delta,
+ f"The calculation of {reporter.__class__.report_type} metrics for the table - "
+ f"'{reporter.table_name}' has started"
+ )
+
+ reporter.report()
+
+ if (
+ reporter.__class__.report_type == "accuracy"
+ and "accuracy" not in reporter.config["reports"]
+ and "metrics_only" in reporter.config["reports"]
+ ):
+ message = (
+ f"The metrics for the table - '{reporter.table_name}' have been evaluated"
+ )
+ else:
+ message = (
+ f"The {reporter.__class__.report_type} report of the table - "
+ f"'{reporter.table_name}' has been generated"
+ )
+ cls._log_and_update_progress(delta, message)
+
+ @staticmethod
+ def _log_and_update_progress(delta: float, message: str):
+ ProgressBarHandler().set_progress(delta=delta, message=message)
+ logger.info(message)
+ ProgressBarHandler().set_progress(
+ progress=ProgressBarHandler().progress + delta,
+ delta=delta,
+ message=message
+ )
+
@property
def reporters(self) -> Dict[str, Reporter]:
return self._reporters
@@ -287,7 +307,7 @@ def reporters(self) -> Dict[str, Reporter]:
class AccuracyReporter(Reporter):
"""
- Reporter for running accuracy test
+ Reporter for running an accuracy test
"""
report_type = "accuracy"
@@ -296,6 +316,7 @@ def report(self):
"""
Run the report
"""
+ original, synthetic = self._extract_report_data()
(
original,
synthetic,
@@ -303,7 +324,7 @@ def report(self):
int_columns,
categorical_columns,
date_columns,
- ) = self.preprocess_data()
+ ) = self.preprocess_data(original, synthetic)
accuracy_test = AccuracyTest(
original,
synthetic,
@@ -320,7 +341,7 @@ def report(self):
class SampleAccuracyReporter(Reporter):
"""
- Reporter for running accuracy test
+ Reporter for running a sample test
"""
report_type = "sample"
@@ -334,6 +355,7 @@ def report(self):
"""
Run the report
"""
+ original, sampled = self._extract_report_data()
(
original,
sampled,
@@ -341,7 +363,7 @@ def report(self):
int_columns,
categorical_columns,
date_columns,
- ) = self.preprocess_data()
+ ) = self.preprocess_data(original, sampled)
accuracy_test = SampleAccuracyTest(
original,
sampled,
diff --git a/src/syngen/ml/strategies/strategies.py b/src/syngen/ml/strategies/strategies.py
index b1e58f54..aca9457d 100644
--- a/src/syngen/ml/strategies/strategies.py
+++ b/src/syngen/ml/strategies/strategies.py
@@ -23,8 +23,6 @@ class Strategy(ABC):
def __init__(self):
self.handler = None
self.config = None
- self.metadata = None
- self.table_name = None
@abstractmethod
def run(self, *args, **kwargs):
@@ -45,17 +43,6 @@ def add_reporters(self):
"""
pass
- def set_metadata(self, metadata):
- if metadata:
- self.metadata = metadata
- return self
- if self.config.table_name:
- metadata = {"table_name": self.config.table_name}
- self.metadata = metadata
- return self
- else:
- raise AttributeError("Either table name or path to metadata MUST be provided")
-
class TrainStrategy(Strategy, ABC):
"""
@@ -82,14 +69,14 @@ def add_handler(self):
Set up the handler which used in training process
"""
root_handler = RootHandler(
- metadata=self.metadata,
+ metadata=self.config.metadata,
table_name=self.config.table_name,
paths=self.config.paths,
loader=self.config.loader
)
vae_handler = VaeTrainHandler(
- metadata=self.metadata,
+ metadata=self.config.metadata,
table_name=self.config.table_name,
schema=self.config.schema,
paths=self.config.paths,
@@ -98,12 +85,12 @@ def add_handler(self):
row_subset=self.config.row_subset,
drop_null=self.config.drop_null,
batch_size=self.config.batch_size,
- print_report=self.config.print_report,
+ reports=self.config.reports,
type_of_process="train",
)
long_text_handler = LongTextsHandler(
- metadata=self.metadata,
+ metadata=self.config.metadata,
table_name=self.config.table_name,
schema=self.config.schema,
paths=self.config.paths,
@@ -116,14 +103,9 @@ def add_handler(self):
def add_reporters(self, **kwargs):
table_name = self.config.table_name
- source = self.config.paths["source_path"]
- loader = self.config.loader
if (
not table_name.endswith("_fk")
- and source is not None
- and loader is None
- and os.path.exists(source)
- and self.config.print_report
+ and "sample" in self.config.reports
):
sample_reporter = SampleAccuracyReporter(
table_name=get_initial_table_name(table_name),
@@ -147,19 +129,8 @@ def run(self, **kwargs):
run_name=f"{table}-PREPROCESS",
tags={"table_name": table, "process": "preprocess"},
)
- self.set_config(
- source=kwargs["source"],
- epochs=kwargs["epochs"],
- drop_null=kwargs["drop_null"],
- row_limit=kwargs["row_limit"],
- table_name=table,
- metadata_path=kwargs["metadata_path"],
- print_report=kwargs["print_report"],
- batch_size=kwargs["batch_size"],
- loader=kwargs["loader"]
- )
-
- self.add_reporters().set_metadata(kwargs["metadata"]).add_handler()
+ self.set_config(**kwargs)
+ self.add_reporters().add_handler()
self.handler.handle()
# End the separate run for the training stage
MlflowTracker().end_run()
@@ -190,10 +161,9 @@ def add_handler(self, type_of_process: str):
"""
Set up the handler which used in infer process
"""
-
self.handler = VaeInferHandler(
- metadata=self.metadata,
metadata_path=self.config.metadata_path,
+ metadata=self.config.metadata,
table_name=self.config.table_name,
paths=self.config.paths,
wrapper_name=VanillaVAEWrapper.__name__,
@@ -201,8 +171,7 @@ def add_handler(self, type_of_process: str):
random_seed=self.config.random_seed,
batch_size=self.config.batch_size,
run_parallel=self.config.run_parallel,
- print_report=self.config.print_report,
- get_infer_metrics=self.config.get_infer_metrics,
+ reports=self.config.reports,
log_level=self.config.log_level,
type_of_process=type_of_process,
loader=self.config.loader
@@ -213,7 +182,7 @@ def add_reporters(self):
table_name = self.config.table_name
if (
not table_name.endswith("_fk") and
- (self.config.print_report or self.config.get_infer_metrics)
+ any([item in ["accuracy", "metrics_only"] for item in self.config.reports])
):
accuracy_reporter = AccuracyReporter(
table_name=get_initial_table_name(table_name),
@@ -229,37 +198,21 @@ def run(self, **kwargs):
"""
Launch the infer process
"""
+ table_name = kwargs["table_name"]
try:
-
- self.set_config(
- destination=kwargs["destination"],
- size=kwargs["size"],
- table_name=kwargs["table_name"],
- metadata_path=kwargs["metadata_path"],
- run_parallel=kwargs["run_parallel"],
- batch_size=kwargs["batch_size"],
- random_seed=kwargs["random_seed"],
- print_report=kwargs["print_report"],
- get_infer_metrics=kwargs["get_infer_metrics"],
- log_level=kwargs["log_level"],
- both_keys=kwargs["both_keys"],
- loader=kwargs["loader"]
- )
-
+ self.set_config(**kwargs)
MlflowTracker().log_params(self.config.to_dict())
-
- self.add_reporters().set_metadata(kwargs["metadata"]).add_handler(
- type_of_process=kwargs["type_of_process"]
- )
+ self.add_reporters()
+ self.add_handler(type_of_process=kwargs["type_of_process"])
self.handler.handle()
except Exception:
logger.error(
- f"Generation of the table - \"{kwargs['table_name']}\" failed on running stage.\n"
+ f"Generation of the table - \"{table_name}\" failed on running stage.\n"
f"The traceback of the error - {traceback.format_exc()}"
)
raise
else:
logger.info(
- f"Synthesis of the table - {kwargs['table_name']} was completed. "
+ f"Synthesis of the table - \"{table_name}\" was completed. "
f"Synthetic data saved in {self.handler.paths['path_to_merged_infer']}"
)
diff --git a/src/syngen/ml/utils/__init__.py b/src/syngen/ml/utils/__init__.py
index 64e390a0..d18e2193 100644
--- a/src/syngen/ml/utils/__init__.py
+++ b/src/syngen/ml/utils/__init__.py
@@ -22,5 +22,6 @@
ProgressBarHandler,
check_if_logs_available,
get_initial_table_name,
+ validate_parameter_reports,
timing
)
diff --git a/src/syngen/ml/utils/utils.py b/src/syngen/ml/utils/utils.py
index 838a6c35..13e8904e 100644
--- a/src/syngen/ml/utils/utils.py
+++ b/src/syngen/ml/utils/utils.py
@@ -1,9 +1,10 @@
import os
import sys
import re
-from typing import List, Dict, Optional, Union, Set
+from typing import List, Dict, Optional, Union, Set, Callable
from dateutil import parser
from datetime import datetime, timedelta
+import time
import pandas as pd
import numpy as np
@@ -13,7 +14,6 @@
from ulid import ULID
import random
from loguru import logger
-import time
MAX_ALLOWED_TIME_MS = 253402214400
MIN_ALLOWED_TIME_MS = -62135596800
@@ -422,3 +422,38 @@ def wrapper(*args, **kwargs):
)
return result
return wrapper
+
+
+def validate_parameter_reports(report_types: list, full_list: list) -> Callable:
+ """
+ Validate the values of the parameter 'reports'
+ """
+ def validator(ctx, param, value) -> List[str]:
+ input_values = set(value)
+ valid_values: List = ["none", "all"]
+ valid_values.extend(report_types)
+
+ if not input_values.issubset(set(valid_values)):
+ raise ValueError(
+ f"Invalid input: Acceptable values for the parameter '--reports' are "
+ f"{', '.join(valid_values)}."
+ )
+ if "none" in input_values and "all" in input_values:
+ raise ValueError(
+ "Invalid input: The '--reports' parameter cannot be set to both 'none' and 'all'. "
+ "Please provide only one of these options."
+ )
+
+ if "none" in input_values or "all" in input_values:
+ if len(input_values) > 1:
+ raise ValueError(
+ "Invalid input: When '--reports' option is set to 'none' or 'all', "
+ "no other values should be provided."
+ )
+ if value[0] == "all":
+ return full_list
+ if value[0] == "none":
+ return list()
+
+ return list(input_values)
+ return validator
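
A usage sketch of the validator factory above (it assumes the syngen package is installed; click normally supplies the `ctx` and `param` arguments, which are unused here, so `None` is passed):

```python
from syngen.ml.utils import validate_parameter_reports
from syngen.ml.validation_schema import ReportTypes

validate_reports = validate_parameter_reports(
    report_types=ReportTypes().infer_report_types,
    full_list=ReportTypes().full_list_of_infer_report_types,
)

# the result is built from a set, so its order may vary
print(validate_reports(None, None, ("accuracy", "metrics_only")))  # ['accuracy', 'metrics_only']
print(validate_reports(None, None, ("all",)))   # ['accuracy']
print(validate_reports(None, None, ("none",)))  # []
```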
diff --git a/src/syngen/ml/vae/models/dataset.py b/src/syngen/ml/vae/models/dataset.py
index c4a7374c..ebe9d2fd 100644
--- a/src/syngen/ml/vae/models/dataset.py
+++ b/src/syngen/ml/vae/models/dataset.py
@@ -132,12 +132,6 @@ def _preprocess_df(self, excluded_columns: Set[str]):
"""
self._cast_to_numeric(excluded_columns)
self.nan_labels_dict = get_nan_labels(self.df, excluded_columns)
- if self.nan_labels_dict and self.format.get("na_values", []):
- logger.info(
- f"Despite the fact that data loading utilized the 'format' section "
- f"for handling NA values, some values have been detected by the algorithm "
- f"as NA labels in the columns - {self.nan_labels_dict}"
- )
self.df = nan_labels_to_float(self.df, self.nan_labels_dict)
def _preparation_step(self):
@@ -415,6 +409,12 @@ def _common_detection(self):
self._set_uuid_columns()
self._set_long_text_columns()
self._set_email_columns()
+ if self.nan_labels_dict and self.format.get("na_values", []):
+ logger.info(
+ f"Despite the fact that data loading utilized the 'format' section "
+ f"for handling NA values, some values have been detected by the algorithm "
+ f"as NA labels in the columns - {self.nan_labels_dict}"
+ )
def _update_schema(self):
"""
diff --git a/src/syngen/ml/vae/wrappers/wrappers.py b/src/syngen/ml/vae/wrappers/wrappers.py
index bb855a87..ee8e6b69 100644
--- a/src/syngen/ml/vae/wrappers/wrappers.py
+++ b/src/syngen/ml/vae/wrappers/wrappers.py
@@ -94,6 +94,7 @@ def _update_dataset(self):
Update dataset object related to the current process
"""
self.dataset.paths = self.paths
+ self.dataset.metadata = self.metadata
self.dataset.main_process = self.main_process
def _save_dataset(self):
diff --git a/src/syngen/ml/validation_schema/__init__.py b/src/syngen/ml/validation_schema/__init__.py
index e157bdc1..cb018f21 100644
--- a/src/syngen/ml/validation_schema/__init__.py
+++ b/src/syngen/ml/validation_schema/__init__.py
@@ -8,4 +8,5 @@
KeysSchema,
ValidationSchema,
SUPPORTED_EXCEL_EXTENSIONS,
+ ReportTypes
)
diff --git a/src/syngen/ml/validation_schema/validation_schema.py b/src/syngen/ml/validation_schema/validation_schema.py
index 934d1310..ac52c7fe 100644
--- a/src/syngen/ml/validation_schema/validation_schema.py
+++ b/src/syngen/ml/validation_schema/validation_schema.py
@@ -1,6 +1,7 @@
-from typing import Dict, Literal
+from typing import Dict, Literal, List
import json
from pathlib import Path
+from dataclasses import dataclass, field
from marshmallow import (
Schema,
@@ -15,6 +16,36 @@
SUPPORTED_EXCEL_EXTENSIONS = [".xls", ".xlsx"]
+@dataclass
+class ReportTypes:
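+    """
+    Registry of the report types available for the train and infer processes
+    """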
+ infer_report_types: List[str] = field(default_factory=lambda: ["accuracy", "metrics_only"])
+ train_report_types: List[str] = field(init=False)
+ excluded_reports: List[str] = field(default_factory=lambda: ["metrics_only"])
+ full_list_of_train_report_types: List[str] = field(init=False)
+ full_list_of_infer_report_types: List[str] = field(init=False)
+
+ def __post_init__(self):
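+        # the 'sample' report only applies to training, where the original data may be subsampled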
+ self.train_report_types = self.infer_report_types + ["sample"]
+ self.full_list_of_train_report_types = self.get_list_of_report_types("train")
+ self.full_list_of_infer_report_types = self.get_list_of_report_types("infer")
+
+ def get_list_of_report_types(self, report_type):
+ """
+ Get the full list of reports that should be generated
+ if the parameter 'reports' sets to 'all'
+ """
+ report_types = (
+ self.train_report_types
+ if report_type == "train"
+ else self.infer_report_types
+ )
+ return [
+ report
+ for report in report_types
+ if report not in self.excluded_reports
+ ]
+
+
class ReferenceSchema(Schema):
table = fields.String(required=True, allow_none=False)
columns = fields.List(fields.String(), required=True, allow_none=False)
@@ -72,11 +103,29 @@ def validate_references(self, data, **kwargs):
class TrainingSettingsSchema(Schema):
+ @staticmethod
+ def validate_reports(x):
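+        """
+        Ensure 'reports' is a list of supported train report types
+        """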
+        if any(i in ("all", "none") for i in x):
+            raise ValidationError(
+                "The values 'all' and 'none' are not allowed inside the list."
+            )
+ if not (
+ isinstance(x, list)
+ and all(
+ isinstance(elem, str)
+ and elem in ReportTypes().train_report_types for elem in x
+ )
+ ):
+ raise ValidationError("Invalid value.")
+
epochs = fields.Integer(validate=validate.Range(min=1), required=False)
drop_null = fields.Boolean(required=False)
row_limit = fields.Integer(validate=validate.Range(min=1), allow_none=True, required=False)
batch_size = fields.Integer(validate=validate.Range(min=1), required=False)
- print_report = fields.Boolean(required=False)
+ reports = fields.Raw(
+ required=False,
+ validate=validate_reports
+ )
class ExtendedRestrictedTrainingSettingsSchema(TrainingSettingsSchema):
@@ -92,13 +141,30 @@ class ExtendedTrainingSettingsSchema(ExtendedRestrictedTrainingSettingsSchema):
class InferSettingsSchema(Schema):
+ @staticmethod
+ def validate_reports(x):
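+        """
+        Ensure 'reports' is a list of supported infer report types
+        """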
+        if any(i in ("all", "none") for i in x):
+            raise ValidationError(
+                "The values 'all' and 'none' are not allowed inside the list."
+            )
+ if not (
+ isinstance(x, list)
+ and all(
+ isinstance(elem, str)
+ and elem in ReportTypes().infer_report_types for elem in x
+ )
+ ):
+ raise ValidationError("Invalid value.")
+
destination = fields.String(required=False)
size = fields.Integer(validate=validate.Range(min=1), required=False)
run_parallel = fields.Boolean(required=False)
batch_size = fields.Integer(validate=validate.Range(min=1), allow_none=True, required=False)
random_seed = fields.Integer(validate=validate.Range(min=0), allow_none=True, required=False)
- print_report = fields.Boolean(required=False)
- get_infer_metrics = fields.Boolean(required=False)
+ reports = fields.Raw(
+ required=False,
+ validate=validate_reports
+ )
class CSVFormatSettingsSchema(Schema):
diff --git a/src/syngen/ml/worker/worker.py b/src/syngen/ml/worker/worker.py
index 0f3bc3fc..639ae8e9 100644
--- a/src/syngen/ml/worker/worker.py
+++ b/src/syngen/ml/worker/worker.py
@@ -16,6 +16,7 @@
from syngen.ml.context.context import global_context
from syngen.ml.utils import ProgressBarHandler
from syngen.ml.mlflow_tracker import MlflowTracker
+from syngen.ml.validation_schema import ReportTypes
@define
@@ -135,6 +136,7 @@ def __fetch_metadata(self) -> Dict:
},
"infer_settings": {},
"keys": {},
+ "format": {}
}
}
return metadata
@@ -241,15 +243,20 @@ def _split_pk_fk_metadata(self, config, tables):
return config
@staticmethod
- def _should_generate_report(config_of_tables: Dict, type_of_process: str):
+ def _should_generate_data(
+ config_of_tables: Dict,
+ type_of_process: str,
+ list_of_reports: List[str]
+ ):
"""
- Determine whether reports should be generated based
- on the configurations of the tables
+        Determine whether synthetic data should be generated
+        so that the requested reports can be built from it
"""
return any(
[
- config.get(f"{type_of_process}_settings", {}).get("print_report", False)
- for table, config in config_of_tables.items()
+ report in list_of_reports
+ for config in config_of_tables.values()
+ for report in config.get(f"{type_of_process}_settings", {}).get("reports", [])
]
)
@@ -289,16 +296,15 @@ def _train_table(self, table, metadata, delta):
drop_null=train_settings["drop_null"],
row_limit=train_settings["row_limit"],
table_name=table,
- metadata_path=self.metadata_path,
- print_report=train_settings["print_report"],
+ reports=train_settings["reports"],
batch_size=train_settings["batch_size"],
loader=self.loader
)
- self._write_success_message(slugify(table))
+ self._write_success_file(slugify(table))
self._save_metadata_file()
ProgressBarHandler().set_progress(
delta=delta,
- message=f"Training of the table - {table} was completed"
+ message=f"Training of the table - '{table}' was completed"
)
def __train_tables(
@@ -329,8 +335,6 @@ def __train_tables(
type_of_process="train"
)
- self._generate_reports()
-
def _get_surrogate_tables_mapping(self):
"""
Get the mapping of surrogate tables, which end with "_pk" and "_fk",
@@ -375,15 +379,13 @@ def _infer_table(self, table, metadata, type_of_process, delta, is_nested=False)
InferStrategy().run(
destination=settings.get("destination") if type_of_process == "infer" else None,
metadata=metadata,
+ metadata_path=self.metadata_path,
size=settings.get("size") if type_of_process == "infer" else None,
table_name=table,
- metadata_path=self.metadata_path,
run_parallel=settings.get("run_parallel") if type_of_process == "infer" else False,
batch_size=settings.get("batch_size") if type_of_process == "infer" else 1000,
random_seed=settings.get("random_seed") if type_of_process == "infer" else 1,
- print_report=settings["print_report"],
- get_infer_metrics=settings.get("get_infer_metrics")
- if type_of_process == "infer" else False,
+ reports=settings["reports"],
log_level=self.log_level,
both_keys=both_keys,
type_of_process=self.type_of_process,
@@ -391,7 +393,7 @@ def _infer_table(self, table, metadata, type_of_process, delta, is_nested=False)
)
ProgressBarHandler().set_progress(
delta=delta,
- message=f"Infer process of the table - {table} was completed"
+ message=f"Infer process of the table - '{table}' was completed"
)
MlflowTracker().end_run()
@@ -433,8 +435,6 @@ def __infer_tables(
)
MlflowTracker().end_run()
- self._generate_reports()
-
def _collect_integral_metrics(self, tables, type_of_process):
"""
Collect the integral metrics depending on the type of process
@@ -447,11 +447,11 @@ def _generate_reports(self):
"""
Generate reports
"""
- Report().generate_report(type_of_process=self.type_of_process.upper())
+ Report().generate_report()
Report().clear_report()
@staticmethod
- def _write_success_message(table_name: str):
+ def _write_success_file(table_name: str):
"""
Write success message to the '.success' file
"""
@@ -482,7 +482,11 @@ def launch_train(self):
metadata_for_inference,
) = metadata_for_inference
- generation_of_reports = self._should_generate_report(metadata_for_training, "train")
+ generation_of_reports = self._should_generate_data(
+ metadata_for_training,
+ "train",
+ ReportTypes().infer_report_types
+ )
self.__train_tables(
tables_for_training,
@@ -492,6 +496,7 @@ def launch_train(self):
generation_of_reports
)
+ self._generate_reports()
self._collect_metrics_in_train(
tables_for_training,
tables_for_inference,
@@ -515,8 +520,13 @@ def launch_infer(self):
"""
tables, config_of_tables = self._prepare_metadata_for_process(type_of_process="infer")
- generation_of_reports = self._should_generate_report(config_of_tables, "infer")
+ generation_of_reports = self._should_generate_data(
+ config_of_tables,
+ "infer",
+ ReportTypes().infer_report_types
+ )
delta = 0.25 / len(tables) if generation_of_reports else 0.5 / len(tables)
self.__infer_tables(tables, config_of_tables, delta, type_of_process="infer")
+ self._generate_reports()
self._collect_metrics_in_infer(tables)
diff --git a/src/syngen/streamlit_app/handlers/handlers.py b/src/syngen/streamlit_app/handlers/handlers.py
index d732ffc0..d57695a7 100644
--- a/src/syngen/streamlit_app/handlers/handlers.py
+++ b/src/syngen/streamlit_app/handlers/handlers.py
@@ -1,3 +1,4 @@
+from typing import Literal
import os
from datetime import datetime
from queue import Queue
@@ -26,12 +27,12 @@ def __init__(
self,
epochs: int,
size_limit: int,
- print_report: bool,
+ reports: bool,
uploaded_file: UploadedFile
):
self.epochs = epochs
self.size_limit = size_limit
- self.print_report = print_report
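+        # the single UI checkbox maps to the 'accuracy' report only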
+ self.reports = ["accuracy"] if reports else []
self.uploaded_file = uploaded_file
self.file_name = self.uploaded_file.name
self.table_name = os.path.splitext(self.file_name)[0]
@@ -50,15 +51,14 @@ def __init__(
"row_limit": 10000,
"drop_null": False,
"batch_size": 32,
- "print_report": False
+ "reports": []
}
self.infer_settings = {
"size": self.size_limit,
"batch_size": 32,
"run_parallel": False,
"random_seed": None,
- "print_report": self.print_report,
- "get_infer_metrics": False
+ "reports": self.reports,
}
def set_logger(self):
@@ -89,7 +89,7 @@ def file_sink(self, message: str):
log_message = fetch_log_message(message)
log_file.write(log_message + "\n")
- def _get_worker(self, process_type: str):
+ def _get_worker(self, process_type: Literal["train", "infer"]):
"""
Get a Worker object
@@ -177,7 +177,7 @@ def open_report(self):
"""
Open the accuracy report in the iframe
"""
- if os.path.exists(self.path_to_report) and self.print_report:
+ if os.path.exists(self.path_to_report) and self.reports:
with open(self.path_to_report, "r") as report:
report_content = report.read()
with st.expander("View the accuracy report"):
@@ -197,7 +197,7 @@ def generate_buttons(self):
os.getenv("SUCCESS_LOG_FILE", ""),
f"logs_{self.sl_table_name}.log"
)
- if self.print_report:
+ if self.reports:
self.generate_button(
"Download the accuracy report",
self.path_to_report,
diff --git a/src/syngen/streamlit_app/run.py b/src/syngen/streamlit_app/run.py
index 1d3244f6..75fb9175 100644
--- a/src/syngen/streamlit_app/run.py
+++ b/src/syngen/streamlit_app/run.py
@@ -61,11 +61,11 @@ def handle_cross_icon():
st.markdown(css, unsafe_allow_html=True)
@staticmethod
- def _get_streamlit_handler(epochs, size_limit, print_report, uploaded_file):
+ def _get_streamlit_handler(epochs, size_limit, reports, uploaded_file):
"""
Get the Streamlit handler
"""
- return StreamlitHandler(epochs, size_limit, print_report, uploaded_file)
+ return StreamlitHandler(epochs, size_limit, reports, uploaded_file)
def run_basic_page(self):
"""
@@ -102,13 +102,13 @@ def run_basic_page(self):
value=1000,
disabled=get_running_status()
)
- print_report = st.checkbox(
+ reports = st.checkbox(
"Create an accuracy report",
value=False,
- key="print_report",
+ key="reports",
disabled=get_running_status()
)
- handler = self._get_streamlit_handler(epochs, size_limit, print_report, uploaded_file)
+ handler = self._get_streamlit_handler(epochs, size_limit, reports, uploaded_file)
if st.button(
"Generate data",
type="primary",
diff --git a/src/syngen/train.py b/src/syngen/train.py
index 8bf907cb..ab3ed934 100644
--- a/src/syngen/train.py
+++ b/src/syngen/train.py
@@ -1,6 +1,6 @@
import os
import traceback
-from typing import Optional
+from typing import Optional, List
import click
from loguru import logger
@@ -12,10 +12,23 @@
set_log_path,
check_if_logs_available
)
+from syngen.ml.utils import validate_parameter_reports
+from syngen.ml.validation_schema import ReportTypes
+
+
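+# Click callback that validates the '--reports' values for the train process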
+validate_reports = validate_parameter_reports(
+ report_types=ReportTypes().train_report_types,
+ full_list=ReportTypes().full_list_of_train_report_types
+)
@click.command()
-@click.option("--metadata_path", type=str, default=None, help="Path to the metadata file")
+@click.option(
+ "--metadata_path",
+ type=str,
+ default=None,
+ help="Path to the metadata file"
+)
@click.option(
"--source",
type=str,
@@ -49,11 +62,19 @@
"length will randomly subset the specified rows number",
)
@click.option(
- "--print_report",
- default=False,
- type=click.BOOL,
- help="Whether to print quality report. Might require significant time "
- "for big generated tables (>1000 rows). If absent, it's defaulted to False",
+ "--reports",
+ default=("none",),
+ type=click.UNPROCESSED,
+ multiple=True,
+ callback=validate_reports,
+ help="Controls the generation of quality reports. "
+ "Might require significant time for big generated tables (>1000 rows). "
+ "If set to 'sample', generates a sample report. "
+ "If set to 'accuracy', generates an accuracy report. "
+ "If set to 'metrics_only', outputs the metrics information only to standard output "
+ "without generation of a report. "
+ "If set to 'all', generates both accuracy and sample report. "
+ "If it's absent or set to 'none', no reports are generated.",
)
@click.option(
"--log_level",
@@ -76,7 +97,7 @@ def launch_train(
epochs: int,
drop_null: bool,
row_limit: Optional[int],
- print_report: bool,
+ reports: List[str],
log_level: str,
batch_size: int = 32,
):
@@ -91,7 +112,7 @@ def launch_train(
epochs
drop_null
row_limit
- print_report
+ reports
log_level
batch_size
-------
@@ -147,7 +168,7 @@ def launch_train(
"drop_null": drop_null,
"row_limit": row_limit,
"batch_size": batch_size,
- "print_report": print_report,
+ "reports": reports,
}
worker = Worker(
table_name=table_name,
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index 8cdd7c88..9c66e3a5 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -121,7 +121,7 @@ def test_metadata_storage():
"table_a": {
"train_settings": {
"source": "path/to/table_a.csv",
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {"destination": "path/to/generated_table_a.csv"},
"keys": {
@@ -139,7 +139,7 @@ def test_metadata_storage():
"table_d": {
"train_settings": {
"source": "path/to/table_d.csv",
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {"destination": "path/to/generated_table_d.csv"},
"keys": {
@@ -154,6 +154,19 @@ def test_metadata_storage():
shutil.rmtree("model_artifacts")
+@pytest.fixture
+def test_success_file():
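+    """
+    Create a temporary '.success' progress file and remove the artifacts afterwards
+    """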
+ path_to_test_dir = "model_artifacts/resources/test-table"
+ os.makedirs(path_to_test_dir, exist_ok=True)
+ success_file_path = f"{path_to_test_dir}/message.success"
+ with open(success_file_path, "w") as f:
+ f.write("PROGRESS")
+
+ yield success_file_path
+ if os.path.exists(success_file_path):
+ shutil.rmtree("model_artifacts")
+
+
@pytest.fixture
def test_metadata_file():
return {
@@ -162,11 +175,11 @@ def test_metadata_file():
"source": "..\\data\\pk_test.csv",
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
diff --git a/src/tests/unit/config/test_config.py b/src/tests/unit/config/test_config.py
index 4163dbb6..e037848c 100644
--- a/src/tests/unit/config/test_config.py
+++ b/src/tests/unit/config/test_config.py
@@ -2,10 +2,234 @@
import pytest
from unittest.mock import patch
-from syngen.ml.config import TrainConfig
+from syngen.ml.config import TrainConfig, InferConfig
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME
+@pytest.mark.parametrize(
+ "drop_null, row_limit, expected_row_subset, expected_reports",
+ [
+ (True, None, 801, ["accuracy", "sample"]),
+ (False, 100, 100, ["accuracy", "sample"]),
+ (False, None, 1000, ["accuracy"])
+ ]
+)
+@patch.object(TrainConfig, "_save_input_data")
+@patch.object(TrainConfig, "_remove_existed_artifacts")
+@patch.object(TrainConfig, "_prepare_dirs")
+def test_init_train_config(
+ mock_prepare_dirs,
+ mock_remove_existed_artifacts,
+ mock_save_input_data,
+ drop_null,
+ row_limit,
+ expected_row_subset,
+ expected_reports,
+ rp_logger
+):
+ rp_logger.info(
+ "Test the process of initialization of the instance of the class TrainConfig"
+ )
+ path_to_source = f"{DIR_NAME}/unit/config/fixtures/data_types_detection_set.csv"
+ table_name = "test_table"
+ metadata = {
+ "test_table": {
+ "train_settings": {
+ "source": path_to_source
+ }
+ }
+ }
+ train_config = TrainConfig(
+ source=path_to_source,
+ epochs=10,
+ drop_null=drop_null,
+ row_limit=row_limit,
+ table_name=table_name,
+ metadata=metadata,
+ reports=expected_reports,
+ batch_size=32,
+ loader=None
+ )
+ train_config.preprocess_data()
+ assert train_config.source == path_to_source
+ assert train_config.epochs == 10
+ assert train_config.drop_null == drop_null
+ assert train_config.row_limit == row_limit
+ assert train_config.table_name == table_name
+ assert train_config.metadata == metadata
+ assert train_config.reports == expected_reports
+ assert train_config.batch_size == 32
+ assert train_config.loader is None
+ assert train_config.initial_data_shape == (1000, 11)
+ assert train_config.row_subset == expected_row_subset
+ assert train_config.schema == {"fields": {}, "format": "CSV"}
+ assert train_config.original_schema is None
+ assert train_config.slugify_table_name == "test-table"
+ assert train_config.columns == [
+ "id", "first_name", "last_name", "email",
+ "gender", "gender_abbr", "gender_abbr_3",
+ "age", "price", "date", "comments"
+ ]
+ assert train_config.dropped_columns == set()
+
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+def test_init_infer_config_with_absent_input_data_in_train_process(rp_logger):
+ rp_logger.info(
+ "Test the process of initialization of the instance of the class InferConfig "
+ "during the training process in case the input data is absent"
+ )
+ table_name = "test_table"
+ path_to_source = "path/to/source.csv"
+ metadata = {
+ "test_table": {
+ "train_settings": {
+ "source": path_to_source,
+ "reports": ["accuracy"]
+ }
+ }
+ }
+ infer_config = InferConfig(
+ destination="path/to/destination.csv",
+ metadata=metadata,
+ metadata_path="path/to/metadata.yaml",
+ size=100,
+ table_name=table_name,
+ run_parallel=False,
+ batch_size=100,
+ random_seed=None,
+ reports=["accuracy"],
+ both_keys=True,
+ log_level="DEBUG",
+ loader=None,
+ type_of_process="train"
+ )
+ assert infer_config.reports == []
+
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+def test_init_infer_config_with_absent_input_data_in_infer_process(rp_logger):
+ rp_logger.info(
+ "Test the process of initialization of the instance of the class InferConfig "
+ "during the inference process in case the input data is absent"
+ )
+ table_name = "test_table"
+ path_to_source = "path/to/source.csv"
+ metadata = {
+ "test_table": {
+ "train_settings": {
+ "source": path_to_source
+ },
+ "infer_settings": {
+ "reports": ["accuracy"]
+ }
+ }
+ }
+ infer_config = InferConfig(
+ destination="path/to/destination.csv",
+ metadata=metadata,
+ metadata_path="path/to/metadata.yaml",
+ size=100,
+ table_name=table_name,
+ run_parallel=False,
+ batch_size=100,
+ random_seed=None,
+ reports=["accuracy"],
+ both_keys=True,
+ log_level="DEBUG",
+ loader=None,
+ type_of_process="infer"
+ )
+ assert infer_config.reports == []
+
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+def test_init_infer_config_with_existed_input_data_in_train_process(mocker, rp_logger):
+ rp_logger.info(
+ "Test the process of initialization of the instance of the class InferConfig "
+ "during the training process in case the input data is present"
+ )
+ table_name = "test_table"
+ path_to_source = "path/to/source.csv"
+
+ metadata = {
+ "test_table": {
+ "train_settings": {
+ "source": path_to_source,
+ "reports": ["accuracy"]
+ }
+ }
+ }
+
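+    # simulate that the stored input data exists so the requested reports remain enabled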
+ mocker.patch("syngen.ml.data_loaders.DataLoader.has_existed_path", return_value=True)
+
+ infer_config = InferConfig(
+ destination="path/to/destination.csv",
+ metadata=metadata,
+ metadata_path="path/to/metadata.yaml",
+ size=100,
+ table_name=table_name,
+ run_parallel=False,
+ batch_size=100,
+ random_seed=None,
+ reports=["accuracy"],
+ both_keys=True,
+ log_level="DEBUG",
+ loader=None,
+ type_of_process="train"
+ )
+
+ assert infer_config.reports == ["accuracy"]
+
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+def test_init_infer_config_with_existed_input_data_in_infer_process(mocker, rp_logger):
+ rp_logger.info(
+ "Test the process of initialization of the instance of the class InferConfig "
+ "during the inference process in case the input data is present"
+ )
+ table_name = "test_table"
+
+ metadata = {
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/source.csv",
+ },
+ "infer_settings": {
+ "reports": ["accuracy"]
+ }
+ }
+ }
+
+ mocker.patch("syngen.ml.data_loaders.DataLoader.has_existed_path", return_value=True)
+
+ infer_config = InferConfig(
+ destination="path/to/destination.csv",
+ metadata=metadata,
+ metadata_path="path/to/metadata.yaml",
+ size=100,
+ table_name=table_name,
+ run_parallel=False,
+ batch_size=100,
+ random_seed=None,
+ reports=["accuracy"],
+ both_keys=True,
+ log_level="DEBUG",
+ loader=None,
+ type_of_process="infer"
+ )
+
+ assert infer_config.reports == ["accuracy"]
+
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
def test_get_state_of_train_config(rp_logger):
rp_logger.info("Test the method '__getstate__' of the class TrainConfig")
train_config = TrainConfig(
@@ -14,8 +238,14 @@ def test_get_state_of_train_config(rp_logger):
drop_null=True,
row_limit=1000,
table_name="test_table",
- metadata_path="metadata/path.yaml",
- print_report=True,
+ metadata={
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/data.csv"
+ }
+ }
+ },
+ reports=["accuracy", "sample"],
batch_size=32,
loader=lambda x: pd.DataFrame()
)
@@ -25,8 +255,8 @@ def test_get_state_of_train_config(rp_logger):
"drop_null",
"row_limit",
"table_name",
- "metadata_path",
- "print_report",
+ "metadata",
+ "reports",
"batch_size"
}
state = train_config.__getstate__()
@@ -62,8 +292,83 @@ def test_preprocess_data(
drop_null=drop_null,
row_limit=row_limit,
table_name="test_table",
- metadata_path="metadata/path.yaml",
- print_report=True,
+ metadata={
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/data.csv"
+ }
+ }
+ },
+ reports=["accuracy", "sample"],
+ batch_size=32,
+ loader=None
+ )
+ train_config.preprocess_data()
+ mock_save_input_data.assert_called_once()
+ mock_save_original_schema.assert_called_once()
+ mock_mark_removed_columns.assert_called_once()
+ mock_remove_empty_columns.assert_called_once()
+    assert train_config.row_subset == expected_size
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@pytest.mark.parametrize("drop_null, row_limit, expected_size, expected_metadata", [
+ (False, None, 1000, {
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/data.csv",
+ "reports": ["accuracy"]
+ }
+ }
+ }),
+ (True, None, 801, {
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/data.csv",
+ "reports": ["accuracy", "sample"]
+ }
+ }
+ }),
+ (True, 100, 100, {
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/data.csv",
+ "reports": ["accuracy", "sample"]
+ }
+ }
+ })
+])
+@patch.object(TrainConfig, "_save_original_schema")
+@patch.object(TrainConfig, "_remove_empty_columns")
+@patch.object(TrainConfig, "_mark_removed_columns")
+@patch.object(TrainConfig, "_save_input_data")
+def test_check_reports_in_train_config(
+ mock_save_input_data,
+ mock_mark_removed_columns,
+ mock_remove_empty_columns,
+ mock_save_original_schema,
+ drop_null,
+ row_limit,
+ expected_size,
+ expected_metadata,
+ rp_logger
+):
+ rp_logger.info("Test the method '_check_reports' of the class TrainConfig")
+ train_config = TrainConfig(
+ source=f"{DIR_NAME}/unit/config/fixtures/data_types_detection_set.csv",
+ epochs=10,
+ drop_null=drop_null,
+ row_limit=row_limit,
+ table_name="test_table",
+ metadata={
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/data.csv",
+ "reports": ["accuracy", "sample"]
+ }
+ }
+ },
+ reports=["accuracy", "sample"],
batch_size=32,
loader=None
)
@@ -73,4 +378,5 @@ def test_preprocess_data(
mock_mark_removed_columns.assert_called_once()
mock_remove_empty_columns.assert_called_once()
train_config.row_subset == expected_size
+    assert train_config.metadata == expected_metadata
rp_logger.info(SUCCESSFUL_MESSAGE)
diff --git a/src/tests/unit/data_loaders/fixtures/metadata/metadata.yaml b/src/tests/unit/data_loaders/fixtures/metadata/metadata.yaml
index 81d68e1a..f69dec36 100644
--- a/src/tests/unit/data_loaders/fixtures/metadata/metadata.yaml
+++ b/src/tests/unit/data_loaders/fixtures/metadata/metadata.yaml
@@ -3,14 +3,14 @@ pk_test: # Use table name here
source: ../data/pk_test.csv # Supported formats include cloud storage locations, local files
epochs: 1 # Number of epochs
drop_null: false # Drop rows with NULL values
- print_report: false # Turn on or turn off generation of the report
+    reports: none # Controls generation of quality reports ('none' disables them)
row_limit: 800
infer_settings: # Settings for infer process
size: 100
run_parallel: false # Turn on or turn off parallel training process
random_seed: 1 # Ensure reproducible tables generation
- print_report: true
+ reports: all
keys: # Keys of the table
pk_id: # Name of a key
diff --git a/src/tests/unit/data_loaders/fixtures/metadata/metadata.yml b/src/tests/unit/data_loaders/fixtures/metadata/metadata.yml
index 81d68e1a..f69dec36 100644
--- a/src/tests/unit/data_loaders/fixtures/metadata/metadata.yml
+++ b/src/tests/unit/data_loaders/fixtures/metadata/metadata.yml
@@ -3,14 +3,14 @@ pk_test: # Use table name here
source: ../data/pk_test.csv # Supported formats include cloud storage locations, local files
epochs: 1 # Number of epochs
drop_null: false # Drop rows with NULL values
- print_report: false # Turn on or turn off generation of the report
+    reports: none # Controls generation of quality reports ('none' disables them)
row_limit: 800
infer_settings: # Settings for infer process
size: 100
run_parallel: false # Turn on or turn off parallel training process
random_seed: 1 # Ensure reproducible tables generation
- print_report: true
+ reports: all
keys: # Keys of the table
pk_id: # Name of a key
diff --git a/src/tests/unit/data_loaders/test_data_loaders.py b/src/tests/unit/data_loaders/test_data_loaders.py
index 6dccfd9a..98922e2f 100644
--- a/src/tests/unit/data_loaders/test_data_loaders.py
+++ b/src/tests/unit/data_loaders/test_data_loaders.py
@@ -508,16 +508,17 @@ def test_load_metadata_in_yaml_format(rp_logger):
"source": "../data/pk_test.csv",
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
},
"keys": {"pk_id": {"columns": ["Id"], "type": "PK"}},
+ "format": {}
},
}
rp_logger.info(SUCCESSFUL_MESSAGE)
@@ -537,16 +538,17 @@ def test_load_metadata_in_yml_format(rp_logger):
"source": "../data/pk_test.csv",
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
},
"keys": {"pk_id": {"columns": ["Id"], "type": "PK"}},
+ "format": {}
},
}
rp_logger.info(SUCCESSFUL_MESSAGE)
@@ -563,16 +565,17 @@ def test_load_metadata_by_yaml_loader_in_yaml_format(rp_logger):
"source": "../data/pk_test.csv",
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
},
"keys": {"pk_id": {"columns": ["Id"], "type": "PK"}},
+ "format": {}
}
}
@@ -598,17 +601,18 @@ def test_load_metadata_by_yaml_loader_in_yml_format_without_validation(rp_logger
"train_settings": {
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
},
"keys": {"pk_id": {"columns": ["Id"], "type": "PK"}},
"source": "..\\data\\pk_test.csv",
+ "format": {}
}
}
@@ -634,16 +638,17 @@ def test_save_metadata_in_yaml_format(test_yaml_path, test_metadata_file, rp_log
"source": "..\\data\\pk_test.csv",
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
},
"keys": {"pk_id": {"columns": ["Id"], "type": "PK"}},
+ "format": {}
},
}
rp_logger.info(SUCCESSFUL_MESSAGE)
@@ -662,16 +667,17 @@ def test_save_metadata_in_yml_format(test_yml_path, test_metadata_file, rp_logge
"source": "..\\data\\pk_test.csv",
"drop_null": False,
"epochs": 1,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 1,
"run_parallel": False,
"size": 100,
},
"keys": {"pk_id": {"columns": ["Id"], "type": "PK"}},
+ "format": {}
},
}
rp_logger.info(SUCCESSFUL_MESSAGE)
@@ -691,6 +697,7 @@ def test_load_metadata_with_none_params_in_yaml_format(rp_logger):
"train_settings": {"source": "../data/pk_test.csv"},
"infer_settings": {},
"keys": {},
+ "format": {}
},
}
rp_logger.info(SUCCESSFUL_MESSAGE)
diff --git a/src/tests/unit/handlers/test_handlers.py b/src/tests/unit/handlers/test_handlers.py
index 09a349dc..bec97d8e 100644
--- a/src/tests/unit/handlers/test_handlers.py
+++ b/src/tests/unit/handlers/test_handlers.py
@@ -34,7 +34,7 @@
],
)
def test_get_pk_path(
- mock_os_path_exists, path_to_metadata, expected_path, type_of_process, rp_logger
+ mock_os_path_exists, path_to_metadata, expected_path, type_of_process, rp_logger
):
"""
Test the method '_get_pk_path' of the class VaeInferHandler
@@ -51,8 +51,7 @@ def test_get_pk_path(
size=100,
batch_size=100,
run_parallel=False,
- print_report=False,
- get_infer_metrics=False,
+ reports=[],
wrapper_name="MMDVAEWrapper",
log_level="INFO",
type_of_process=type_of_process,
@@ -99,8 +98,7 @@ def test_split_by_batches(
size=size,
batch_size=batch_size,
run_parallel=False,
- print_report=False,
- get_infer_metrics=False,
+ reports=[],
wrapper_name="MMDVAEWrapper",
log_level="INFO",
type_of_process="infer",
diff --git a/src/tests/unit/launchers/fixtures/metadata.yaml b/src/tests/unit/launchers/fixtures/metadata.yaml
index f06cc161..fff01bef 100644
--- a/src/tests/unit/launchers/fixtures/metadata.yaml
+++ b/src/tests/unit/launchers/fixtures/metadata.yaml
@@ -3,11 +3,11 @@ test_table: # Use table name here
source: ./tests/unit/launchers/metadata.yaml # Supported formats include cloud storage locations, local files
epochs: 8 # Number of epochs
drop_null: false # Drop rows with NULL values
- print_report: true # Turn on or turn off generation of the report
+    reports: all # Which reports to generate: 'none', 'accuracy', 'sample', 'metrics_only', or 'all'
infer_settings: # Settings for infer process
size: 90 # Size for generated data
- print_report: true # Turn on or turn off generation of the report
+    reports: all # Which reports to generate: 'none', 'accuracy', 'metrics_only', or 'all'
keys: # Keys of the table
pk_pk_tst: # Name of a key
diff --git a/src/tests/unit/launchers/test_launch_infer.py b/src/tests/unit/launchers/test_launch_infer.py
index 90f97020..743812ed 100644
--- a/src/tests/unit/launchers/test_launch_infer.py
+++ b/src/tests/unit/launchers/test_launch_infer.py
@@ -1,19 +1,22 @@
from unittest.mock import patch
+import pytest
from click.testing import CliRunner
from syngen.infer import launch_infer
from syngen.ml.worker import Worker
+from syngen.ml.validation_schema import ReportTypes
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME
TABLE_NAME = "test_table"
PATH_TO_METADATA = f"{DIR_NAME}/unit/launchers/fixtures/metadata.yaml"
+INFER_REPORT_TYPES = ReportTypes().infer_report_types
@patch.object(Worker, "launch_infer")
@patch.object(Worker, "__attrs_post_init__")
def test_infer_table_with_table_name(
- mock_post_init, mock_launch_infer, rp_logger
+ mock_post_init, mock_launch_infer, rp_logger
):
rp_logger.info("Launch infer process through CLI with parameter '--table_name'")
runner = CliRunner()
@@ -27,7 +30,7 @@ def test_infer_table_with_table_name(
@patch.object(Worker, "launch_infer")
@patch.object(Worker, "__attrs_post_init__")
def test_infer_table_with_metadata_path(
- mock_post_init, mock_launch_infer, rp_logger
+ mock_post_init, mock_launch_infer, rp_logger
):
rp_logger.info("Launch infer process through CLI with parameter '--metadata_path'")
runner = CliRunner()
@@ -124,7 +127,9 @@ def test_infer_table_with_invalid_run_parallel(rp_logger):
"Launch infer process through CLI with invalid 'run_parallel' parameter equals 'test'"
)
runner = CliRunner()
- result = runner.invoke(launch_infer, ["--run_parallel", "test", "--table_name", TABLE_NAME])
+ result = runner.invoke(
+ launch_infer, ["--run_parallel", "test", "--table_name", TABLE_NAME]
+ )
assert result.exit_code == 2
rp_logger.info(SUCCESSFUL_MESSAGE)
@@ -158,7 +163,7 @@ def test_infer_table_with_invalid_batch_size(rp_logger):
@patch.object(Worker, "launch_infer")
@patch.object(Worker, "__attrs_post_init__")
def test_infer_table_with_valid_random_seed(
- mock_post_init, mock_launch_infer, rp_logger
+ mock_post_init, mock_launch_infer, rp_logger
):
rp_logger.info(
"Launch infer process through CLI with valid 'random_seed' parameter equals 1"
@@ -181,27 +186,104 @@ def test_infer_table_with_invalid_random_seed(rp_logger):
rp_logger.info(SUCCESSFUL_MESSAGE)
+@pytest.mark.parametrize("valid_value", INFER_REPORT_TYPES + ["none", "all"])
@patch.object(Worker, "launch_infer")
@patch.object(Worker, "__attrs_post_init__")
-def test_infer_table_with_valid_print_report(
- mock_post_init, mock_launch_infer, rp_logger
+def test_infer_table_with_valid_parameter_reports(
+ mock_post_init, mock_launch_infer, valid_value, rp_logger
):
rp_logger.info(
- "Launch infer process through CLI with valid 'print_report' parameter equals True"
+ f"Launch infer process through CLI with valid 'reports' parameter equals '{valid_value}'"
)
runner = CliRunner()
- result = runner.invoke(launch_infer, ["--print_report", True, "--table_name", TABLE_NAME])
+ result = runner.invoke(
+ launch_infer, ["--reports", valid_value, "--table_name", TABLE_NAME]
+ )
assert result.exit_code == 0
mock_post_init.assert_called_once()
mock_launch_infer.assert_called_once()
rp_logger.info(SUCCESSFUL_MESSAGE)
-def test_infer_table_with_invalid_print_report(rp_logger):
+@pytest.mark.parametrize(
+ "first_value, second_value",
+ [
+ (pv, i) for pv in INFER_REPORT_TYPES
+ for i in INFER_REPORT_TYPES
+ ]
+)
+@patch.object(Worker, "launch_infer")
+@patch.object(Worker, "__attrs_post_init__")
+def test_infer_table_with_several_valid_parameter_reports(
+ mock_post_init, mock_launch_infer, first_value, second_value, rp_logger
+):
rp_logger.info(
- "Launch infer process through CLI with invalid 'print_report' parameter equals 'test'"
+ f"Launch infer process through CLI "
+ f"with several valid 'reports' parameters equals '{first_value}' and '{second_value}'"
)
runner = CliRunner()
- result = runner.invoke(launch_infer, ["--print_report", "test", "--table_name", TABLE_NAME])
- assert result.exit_code == 2
+ result = runner.invoke(
+ launch_infer,
+ [
+ "--reports",
+ first_value,
+ "--reports",
+ second_value,
+ "--table_name",
+ TABLE_NAME,
+ ],
+ )
+ mock_post_init.assert_called_once()
+ mock_launch_infer.assert_called_once()
+ assert result.exit_code == 0
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@pytest.mark.parametrize("invalid_value", [
+ "sample", "test", ("none", "all"), ("none", "test"), ("all", "test")
+])
+def test_infer_table_with_invalid_parameter_reports(invalid_value, rp_logger):
+ rp_logger.info(
+ f"Launch infer process through CLI "
+ f"with invalid 'reports' parameter equals '{invalid_value}'"
+ )
+ runner = CliRunner()
+ result = runner.invoke(
+ launch_infer, ["--reports", invalid_value, "--table_name", TABLE_NAME]
+ )
+ assert result.exit_code == 1
+ assert isinstance(result.exception, ValueError)
+ assert result.exception.args == (
+ "Invalid input: Acceptable values for the parameter '--reports' "
+ "are none, all, accuracy, metrics_only.",
+ )
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@pytest.mark.parametrize(
+ "prior_value, value",
+ [(pv, i) for pv in ["all", "none"] for i in INFER_REPORT_TYPES]
+)
+def test_infer_table_with_redundant_parameter_reports(prior_value, value, rp_logger):
+ rp_logger.info(
+ f"Launch infer process through CLI with redundant 'reports' parameter: '{value}'"
+ )
+ runner = CliRunner()
+ result = runner.invoke(
+ launch_infer, [
+ "--reports",
+ prior_value,
+ "--reports",
+ value,
+ "--table_name",
+ TABLE_NAME
+ ]
+ )
+ assert result.exit_code == 1
+ assert isinstance(result.exception, ValueError)
+    assert result.exception.args == (
+        "Invalid input: When the '--reports' option is set to 'none' or 'all', "
+        "no other values should be provided.",
+    )
rp_logger.info(SUCCESSFUL_MESSAGE)
diff --git a/src/tests/unit/launchers/test_launch_train.py b/src/tests/unit/launchers/test_launch_train.py
index 9a4ae00d..a98b735f 100644
--- a/src/tests/unit/launchers/test_launch_train.py
+++ b/src/tests/unit/launchers/test_launch_train.py
@@ -1,25 +1,31 @@
from unittest.mock import patch
+import pytest
from click.testing import CliRunner
from syngen.train import launch_train
from syngen.ml.worker import Worker
+from syngen.ml.validation_schema import ReportTypes
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME
TABLE_NAME = "test_table"
PATH_TO_TABLE = f"{DIR_NAME}/unit/launchers/fixtures/table_with_data.csv"
PATH_TO_METADATA = f"{DIR_NAME}/unit/launchers/fixtures/metadata.yaml"
+TRAIN_REPORT_TYPES = ReportTypes().train_report_types
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
def test_train_table_with_source_and_table_name(
- mock_post_init, mock_launch_train, rp_logger
+ mock_post_init, mock_launch_train, rp_logger
):
rp_logger.info(
"Launch train process through CLI with parameters '--source' and '--table_name'"
)
runner = CliRunner()
- result = runner.invoke(launch_train, ["--source", PATH_TO_TABLE, "--table_name", TABLE_NAME])
+ result = runner.invoke(
+ launch_train,
+ ["--source", PATH_TO_TABLE, "--table_name", TABLE_NAME]
+ )
assert result.exit_code == 0
mock_post_init.assert_called_once()
mock_launch_train.assert_called_once()
@@ -29,7 +35,7 @@ def test_train_table_with_source_and_table_name(
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
def test_train_table_with_metadata_path(
- mock_post_init, mock_launch_train, rp_logger
+ mock_post_init, mock_launch_train, rp_logger
):
rp_logger.info("Launch train process through CLI with parameters '--metadata_path'")
runner = CliRunner()
@@ -169,7 +175,7 @@ def test_train_table_without_parameters(rp_logger):
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
def test_train_table_with_valid_epochs(
- mock_post_init, mock_launch_train, monkeypatch, rp_logger
+ mock_post_init, mock_launch_train, monkeypatch, rp_logger
):
rp_logger.info(
"Launch train process through CLI with valid 'epochs' parameter"
@@ -201,7 +207,7 @@ def test_train_table_with_invalid_epochs(rp_logger):
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
def test_train_table_with_valid_drop_null(
- mock_post_init, mock_launch_train, rp_logger
+ mock_post_init, mock_launch_train, rp_logger
):
rp_logger.info(
"Launch train process through CLI with valid 'drop_null' parameter equals 'True'"
@@ -233,7 +239,7 @@ def test_train_table_with_invalid_drop_null(rp_logger):
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
def test_train_table_with_valid_row_limit(
- mock_post_init, mock_launch_train, rp_logger
+ mock_post_init, mock_launch_train, rp_logger
):
rp_logger.info(
"Launch train process through CLI with valid 'row_limit' parameter equals 100"
@@ -262,18 +268,53 @@ def test_train_table_with_invalid_row_limit(rp_logger):
rp_logger.info(SUCCESSFUL_MESSAGE)
+@pytest.mark.parametrize("valid_value", TRAIN_REPORT_TYPES + ["none", "all"])
+@patch.object(Worker, "launch_train")
+@patch.object(Worker, "__attrs_post_init__")
+def test_train_table_with_valid_parameter_reports(
+ mock_post_init, mock_launch_train, valid_value, rp_logger
+):
+ rp_logger.info(
+ f"Launch train process through CLI with valid 'reports' parameter equals '{valid_value}'"
+ )
+ runner = CliRunner()
+ result = runner.invoke(
+ launch_train,
+ ["--reports", valid_value, "--table_name", TABLE_NAME, "--source", PATH_TO_TABLE],
+ )
+ mock_post_init.assert_called_once()
+ mock_launch_train.assert_called_once()
+ assert result.exit_code == 0
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@pytest.mark.parametrize(
+ "first_value, second_value",
+ [
+ (pv, i) for pv in TRAIN_REPORT_TYPES
+ for i in TRAIN_REPORT_TYPES
+ ]
+)
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
-def test_train_table_with_valid_print_report(
- mock_post_init, mock_launch_train, rp_logger
+def test_train_table_with_several_valid_parameter_reports(
+ mock_post_init, mock_launch_train, first_value, second_value, rp_logger
):
rp_logger.info(
- "Launch train process through CLI with valid 'print_report' parameter equals True"
+ f"Launch train process through CLI "
+ f"with several valid 'reports' parameters equals '{first_value}' and '{second_value}'"
)
runner = CliRunner()
result = runner.invoke(
launch_train,
- ["--print_report", True, "--table_name", TABLE_NAME, "--source", PATH_TO_TABLE],
+ [
+ "--reports",
+ first_value,
+ "--reports",
+ second_value,
+ "--table_name", TABLE_NAME,
+ "--source", PATH_TO_TABLE
+ ],
)
mock_post_init.assert_called_once()
mock_launch_train.assert_called_once()
@@ -281,30 +322,63 @@ def test_train_table_with_valid_print_report(
rp_logger.info(SUCCESSFUL_MESSAGE)
-def test_train_table_with_invalid_print_report(rp_logger):
+@pytest.mark.parametrize(
+ "invalid_value", ["test", ("none", "all"), ("none", "test"), ("all", "test")]
+)
+def test_train_table_with_invalid_parameter_reports(invalid_value, rp_logger):
rp_logger.info(
- "Launch train process through CLI with invalid 'print_report' parameter equals 'test'"
+ f"Launch train process through CLI "
+ f"with invalid 'reports' parameter equals '{invalid_value}'"
)
runner = CliRunner()
result = runner.invoke(
launch_train,
[
- "--print_report",
- "test",
+ "--reports",
+ invalid_value,
"--table_name",
TABLE_NAME,
"--source",
PATH_TO_TABLE,
],
)
- assert result.exit_code == 2
+ assert result.exit_code == 1
+ assert isinstance(result.exception, ValueError)
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@pytest.mark.parametrize(
+ "prior_value, value",
+ [(pv, i) for pv in ["all", "none"] for i in TRAIN_REPORT_TYPES]
+)
+def test_train_table_with_redundant_parameter_reports(prior_value, value, rp_logger):
+ rp_logger.info(
+ f"Launch train process through CLI with redundant 'reports' parameter: '{value}'"
+ )
+ runner = CliRunner()
+ result = runner.invoke(
+ launch_train, [
+ "--reports",
+ prior_value,
+ "--reports",
+ value,
+ "--table_name",
+ TABLE_NAME
+ ]
+ )
+ assert result.exit_code == 1
+ assert isinstance(result.exception, ValueError)
+    assert result.exception.args == (
+        "Invalid input: When the '--reports' option is set to 'none' or 'all', "
+        "no other values should be provided.",
+    )
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Worker, "launch_train")
@patch.object(Worker, "__attrs_post_init__")
def test_train_table_with_valid_batch_size(
- mock_post_init, mock_launch_train, rp_logger
+ mock_post_init, mock_launch_train, rp_logger
):
rp_logger.info(
"Launch train process through CLI with valid 'batch_size' parameter equals 100"
diff --git a/src/tests/unit/test_worker/fixtures/metadata.yaml b/src/tests/unit/test_worker/fixtures/metadata.yaml
index 620f82fd..b1162f38 100644
--- a/src/tests/unit/test_worker/fixtures/metadata.yaml
+++ b/src/tests/unit/test_worker/fixtures/metadata.yaml
@@ -1,9 +1,9 @@
-test_table:
+table:
train_settings:
- source: ./path/to/test_table.csv
+ source: ./path/to/table.csv
epochs: 100
drop_null: false
- print_report: false
+ reports: none
row_limit: 800
batch_size: 2000
@@ -11,9 +11,8 @@ test_table:
size: 200
run_parallel: True
random_seed: 2
- print_report: true
+ reports: all
batch_size: 200
- get_infer_metrics: false
keys:
pk_id:
diff --git a/src/tests/unit/test_worker/fixtures/metadata_of_related_tables.yaml b/src/tests/unit/test_worker/fixtures/metadata_of_related_tables.yaml
index ae06c4a0..65ab219f 100644
--- a/src/tests/unit/test_worker/fixtures/metadata_of_related_tables.yaml
+++ b/src/tests/unit/test_worker/fixtures/metadata_of_related_tables.yaml
@@ -8,7 +8,7 @@ pk_test:
infer_settings:
size: 200
run_parallel: True
- print_report: True
+ reports: all
keys:
pk_id:
@@ -21,14 +21,14 @@ fk_test:
source: ./path/to/fk_test.csv
epochs: 5
drop_null: true
- print_report: true
+ reports: all
row_limit: 600
infer_settings:
size: 90
run_parallel: True
random_seed: 2
- print_report: false
+ reports: none
keys:
fk_id:
diff --git a/src/tests/unit/test_worker/fixtures/metadata_with_empty_settings.yaml b/src/tests/unit/test_worker/fixtures/metadata_with_empty_settings.yaml
index d47521d8..7141b7f7 100644
--- a/src/tests/unit/test_worker/fixtures/metadata_with_empty_settings.yaml
+++ b/src/tests/unit/test_worker/fixtures/metadata_with_empty_settings.yaml
@@ -1,6 +1,6 @@
-test_table:
+table:
train_settings:
- source: ./path/to/test_table.csv
+ source: ./path/to/table.csv
infer_settings:
diff --git a/src/tests/unit/test_worker/fixtures/metadata_with_global_settings.yaml b/src/tests/unit/test_worker/fixtures/metadata_with_global_settings.yaml
index 6cddbf6e..c1784af5 100644
--- a/src/tests/unit/test_worker/fixtures/metadata_with_global_settings.yaml
+++ b/src/tests/unit/test_worker/fixtures/metadata_with_global_settings.yaml
@@ -6,7 +6,7 @@ global:
infer_settings:
size: 1000
run_parallel: True
- print_report: True
+ reports: all
pk_test:
train_settings:
@@ -14,7 +14,7 @@ pk_test:
row_limit: 800
infer_settings:
- print_report: False
+ reports: none
keys:
pk_id:
diff --git a/src/tests/unit/test_worker/fixtures/metadata_with_reports.yaml b/src/tests/unit/test_worker/fixtures/metadata_with_reports.yaml
new file mode 100644
index 00000000..4d5c8c36
--- /dev/null
+++ b/src/tests/unit/test_worker/fixtures/metadata_with_reports.yaml
@@ -0,0 +1,22 @@
+test_table:
+ train_settings:
+ source: ./path/to/test_table.csv
+ epochs: 100
+ drop_null: false
+ reports: all
+ row_limit: 800
+ batch_size: 2000
+
+ infer_settings:
+ destination: ./path/to/test_table_infer.csv
+ size: 200
+ run_parallel: True
+ random_seed: 2
+ reports: all
+ batch_size: 200
+
+ keys:
+ pk_id:
+ type: "PK"
+ columns:
+ - Id
diff --git a/src/tests/unit/test_worker/fixtures/metadata_without_reports.yaml b/src/tests/unit/test_worker/fixtures/metadata_without_reports.yaml
new file mode 100644
index 00000000..5893d2bd
--- /dev/null
+++ b/src/tests/unit/test_worker/fixtures/metadata_without_reports.yaml
@@ -0,0 +1,22 @@
+test_table:
+ train_settings:
+ source: ./path/to/test_table.csv
+ epochs: 100
+ drop_null: false
+ reports: none
+ row_limit: 800
+ batch_size: 2000
+
+ infer_settings:
+ destination: ./path/to/test_table_infer.csv
+ size: 200
+ run_parallel: True
+ random_seed: 2
+ reports: none
+ batch_size: 200
+
+ keys:
+ pk_id:
+ type: "PK"
+ columns:
+ - Id
diff --git a/src/tests/unit/test_worker/fixtures/metadata_without_sources.yaml b/src/tests/unit/test_worker/fixtures/metadata_without_sources.yaml
index 758742e3..e539b07d 100644
--- a/src/tests/unit/test_worker/fixtures/metadata_without_sources.yaml
+++ b/src/tests/unit/test_worker/fixtures/metadata_without_sources.yaml
@@ -7,7 +7,7 @@ pk_test:
infer_settings:
size: 200
run_parallel: True
- print_report: True
+ reports: all
keys:
pk_id:
@@ -19,14 +19,14 @@ fk_test:
train_settings:
epochs: 5
drop_null: true
- print_report: true
+ reports: all
row_limit: 600
infer_settings:
size: 90
run_parallel: True
random_seed: 2
- print_report: false
+ reports: none
keys:
fk_id:
diff --git a/src/tests/unit/test_worker/fixtures/metadata_without_train_settings.yaml b/src/tests/unit/test_worker/fixtures/metadata_without_train_settings.yaml
index 8dd22d08..935aa25e 100644
--- a/src/tests/unit/test_worker/fixtures/metadata_without_train_settings.yaml
+++ b/src/tests/unit/test_worker/fixtures/metadata_without_train_settings.yaml
@@ -2,7 +2,7 @@ pk_test:
infer_settings:
size: 200
run_parallel: True
- print_report: True
+ reports: all
keys:
pk_id:
@@ -15,7 +15,7 @@ fk_test:
size: 90
run_parallel: True
random_seed: 2
- print_report: false
+ reports: none
keys:
fk_id:
diff --git a/src/tests/unit/test_worker/test_worker.py b/src/tests/unit/test_worker/test_worker.py
index 5369746d..61d25bc9 100644
--- a/src/tests/unit/test_worker/test_worker.py
+++ b/src/tests/unit/test_worker/test_worker.py
@@ -1,4 +1,7 @@
from unittest.mock import patch, MagicMock
+import pytest
+
+from marshmallow.exceptions import ValidationError
from syngen.ml.worker import Worker
from syngen.ml.config import Validator
@@ -7,9 +10,9 @@
@patch.object(Validator, "run")
-def test_init_worker_for_training_process_with_absent_metadata(mock_validator_run, rp_logger):
+def test_init_worker_for_training_process_with_absent_metadata_path(mock_validator_run, rp_logger):
"""
- Test the initialization of 'Worker' class with the absent metadata
+ Test the initialization of 'Worker' class with the absent metadata path
during the training process
"""
rp_logger.info(
@@ -25,7 +28,7 @@ def test_init_worker_for_training_process_with_absent_metadata(mock_validator_ru
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
@@ -37,11 +40,12 @@ def test_init_worker_for_training_process_with_absent_metadata(mock_validator_ru
"batch_size": 1000,
"drop_null": True,
"epochs": 20,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 1000,
},
"infer_settings": {},
"keys": {},
+ "format": {}
}
}
mock_validator_run.assert_called_once()
@@ -49,9 +53,9 @@ def test_init_worker_for_training_process_with_absent_metadata(mock_validator_ru
@patch.object(Validator, "run")
-def test_init_worker_for_infer_process_with_absent_metadata(mock_validator_run, rp_logger):
+def test_init_worker_for_infer_process_with_absent_metadata_path(mock_validator_run, rp_logger):
"""
- Test the initialization of 'Worker' class with the absent metadata
+ Test the initialization of 'Worker' class with the absent metadata path
during the inference process
"""
rp_logger.info(
@@ -59,31 +63,30 @@ def test_init_worker_for_infer_process_with_absent_metadata(mock_validator_run,
"with the absent metadata during the inference process"
)
worker = Worker(
- table_name="test_table",
+ table_name="table",
metadata_path=None,
settings={
"size": 100,
"run_parallel": False,
"batch_size": 100,
- "print_report": False,
- "get_infer_metrics": False,
+ "reports": [],
"random_seed": 1,
},
log_level="INFO",
type_of_process="infer",
)
assert worker.metadata == {
- "test_table": {
+ "table": {
"train_settings": {"source": None},
"infer_settings": {
"size": 100,
"run_parallel": False,
"batch_size": 100,
- "print_report": False,
- "get_infer_metrics": False,
+ "reports": [],
"random_seed": 1,
},
"keys": {},
+ "format": {}
}
}
mock_validator_run.assert_called_once()
@@ -91,16 +94,12 @@ def test_init_worker_for_infer_process_with_absent_metadata(mock_validator_run,
@patch.object(Validator, "run")
-def test_init_worker_with_metadata(mock_validator_run, rp_logger):
+def test_init_worker_with_metadata_path(mock_validator_run, rp_logger):
"""
- Test the initialization of 'Worker' class with the metadata
- contained the information of one table with only the primary key
- during the training process
+ Test the initialization of 'Worker' class with the metadata path
"""
rp_logger.info(
- "Test the initialization of the instance of 'Worker' class "
- "with provided metadata contained the information of one table "
- "with only the primary key during the training process"
+ "Test the initialization of the instance of 'Worker' class with the metadata path"
)
worker = Worker(
table_name=None,
@@ -111,19 +110,19 @@ def test_init_worker_with_metadata(mock_validator_run, rp_logger):
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
)
assert worker.metadata == {
"global": {},
- "test_table": {
+ "table": {
"train_settings": {
- "source": "./path/to/test_table.csv",
+ "source": "./path/to/table.csv",
"epochs": 100,
"drop_null": False,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
"batch_size": 2000,
},
@@ -131,19 +130,23 @@ def test_init_worker_with_metadata(mock_validator_run, rp_logger):
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
- "get_infer_metrics": False,
+ "reports": ["accuracy"],
"batch_size": 200,
},
- "keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["Id"]
+ }
+ },
+ "format": {}
},
}
- mock_validator_run.assert_called_once()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "run")
-def test_init_worker_with_empty_settings_in_metadata(mock_validator_run, rp_logger):
+def test_init_worker_with_empty_settings_in_metadata_in_train_process(mock_validator_run, rp_logger):
"""
Test the initialization during the training process
of 'Worker' class with metadata contained the information of one table
@@ -156,42 +159,84 @@ def test_init_worker_with_empty_settings_in_metadata(mock_validator_run, rp_logg
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_with_empty_settings.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_empty_settings.yaml",
settings={
"source": None,
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
)
assert worker.metadata == {
"global": {},
- "test_table": {
+ "table": {
"train_settings": {
- "source": "./path/to/test_table.csv",
+ "source": "./path/to/table.csv",
"epochs": 20,
"drop_null": True,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 1000,
"batch_size": 1000,
},
"infer_settings": {},
"keys": {},
+ "format": {}
},
}
- mock_validator_run.assert_called_once()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "run")
-def test_init_worker_for_training_with_metadata_with_global_settings(
- mock_validator_run, rp_logger
-):
+def test_init_worker_with_empty_settings_in_metadata_in_infer_process(mock_validator_run, rp_logger):
+ """
+ Test the initialization during the inference process
+ of 'Worker' class with metadata contained the information of one table
+ in which the training, inference, keys settings are empty
+ """
+ rp_logger.info(
+ "Test the initialization of the instance of 'Worker' class with provided metadata "
+ "contained the information of one table in which 'train_settings', 'infer_settings', and "
+ "'keys' are empty during the inference process"
+ )
+ worker = Worker(
+ table_name=None,
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_empty_settings.yaml",
+ settings={
+ "size": 200,
+ "run_parallel": False,
+ "batch_size": 200,
+ "reports": ["accuracy"],
+ "random_seed": 5,
+ },
+ log_level="INFO",
+ type_of_process="infer",
+ )
+ assert worker.metadata == {
+ "global": {},
+ "table": {
+ "train_settings": {
+ "source": "./path/to/table.csv"
+ },
+ "infer_settings": {
+ "size": 200,
+ "run_parallel": False,
+ "batch_size": 200,
+ "random_seed": 5,
+ "reports": ["accuracy"]
+ },
+ "keys": {},
+ "format": {}
+ },
+ }
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Validator, "run")
+def test_init_worker_for_training_with_metadata_with_global_settings(mock_validator_run, rp_logger):
"""
Test the initialization of 'Worker' class during the training process
with the metadata contained related tables and global settings
@@ -202,24 +247,27 @@ def test_init_worker_for_training_with_metadata_with_global_settings(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_with_global_settings.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_global_settings.yaml",
settings={
"source": None,
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
)
assert worker.metadata == {
"global": {
- "train_settings": {"drop_null": True, "epochs": 5, "row_limit": 500},
+ "train_settings": {
+ "drop_null": True,
+ "epochs": 5,
+ "row_limit": 500
+ },
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"run_parallel": True,
"size": 1000,
},
@@ -231,10 +279,11 @@ def test_init_worker_for_training_with_metadata_with_global_settings(
"epochs": 5,
"drop_null": True,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
- "infer_settings": {"print_report": False},
+ "infer_settings": {"reports": []},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"keys": {
@@ -250,19 +299,17 @@ def test_init_worker_for_training_with_metadata_with_global_settings(
"drop_null": True,
"row_limit": 500,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {},
+ "format": {}
},
}
- mock_validator_run.assert_called_once()
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Validator, "run")
-def test_init_worker_for_inference_with_metadata_with_global_settings(
- mock_validator_run, rp_logger
-):
+def test_init_worker_for_inference_with_metadata_with_global_settings(
+ mock_validator_run, rp_logger
+):
"""
Test the initialization of 'Worker' class during an inference process
with metadata contained the information of related tables with the global settings
@@ -273,14 +320,12 @@ def test_init_worker_for_inference_with_metadata_with_global_settings(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_with_global_settings.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_global_settings.yaml",
settings={
"size": 200,
"run_parallel": False,
"batch_size": 200,
- "print_report": False,
- "get_infer_metrics": False,
+ "reports": [],
"random_seed": 5,
},
log_level="INFO",
@@ -294,7 +339,7 @@ def test_init_worker_for_inference_with_metadata_with_global_settings(
"row_limit": 500
},
"infer_settings": {
- "print_report": True,
+ "reports": ["accuracy"],
"run_parallel": True,
"size": 1000,
},
@@ -305,14 +350,14 @@ def test_init_worker_for_inference_with_metadata_with_global_settings(
"row_limit": 800
},
"infer_settings": {
- "print_report": False,
- "get_infer_metrics": False,
+ "reports": [],
"size": 1000,
"run_parallel": True,
"batch_size": 200,
"random_seed": 5,
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"keys": {
@@ -326,24 +371,24 @@ def test_init_worker_for_inference_with_metadata_with_global_settings(
"infer_settings": {
"size": 1000,
"run_parallel": True,
- "print_report": True,
- "get_infer_metrics": False,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 5,
},
+ "format": {}
},
}
- mock_validator_run.assert_called_once()
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
+@patch.object(Worker, "_Worker__train_tables")
def test_launch_train_with_metadata(
mock_train_tables,
mock_gather_existed_columns,
@@ -351,6 +396,7 @@ def test_launch_train_with_metadata(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -373,22 +419,22 @@ def test_launch_train_with_metadata(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
)
worker.launch_train()
mock_train_tables.assert_called_once_with(
- ["test_table"],
- ["test_table"],
+ ["table"],
+ ["table"],
{
- "test_table": {
+ "table": {
"train_settings": {
- "source": "./path/to/test_table.csv",
+ "source": "./path/to/table.csv",
"epochs": 100,
"drop_null": False,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
"batch_size": 2000,
},
@@ -396,20 +442,20 @@ def test_launch_train_with_metadata(
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
- "get_infer_metrics": False,
+ "reports": ["accuracy"],
"batch_size": 200,
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
}
},
{
- "test_table": {
+ "table": {
"train_settings": {
- "source": "./path/to/test_table.csv",
+ "source": "./path/to/table.csv",
"epochs": 100,
"drop_null": False,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
"batch_size": 2000,
},
@@ -417,35 +463,37 @@ def test_launch_train_with_metadata(
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
- "get_infer_metrics": False,
+ "reports": ["accuracy"],
"batch_size": 200,
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
}
},
False
)
- mock_gather_existed_columns.assert_called_once()
- mock_check_existence_of_source.assert_called_once()
- mock_check_existence_of_key_columns.assert_called_once()
- mock_check_existence_of_referenced_columns.assert_called_once()
- mock_validate_metadata.assert_called_once()
+ mock_gather_existed_columns.assert_called_once_with("table")
+ mock_check_existence_of_source.assert_called_once_with("table")
+ mock_check_existence_of_key_columns.assert_called_once_with("table")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table")
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
- ["test_table"],
- ["test_table"],
+ ["table"],
+ ["table"],
False
)
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
+@patch.object(Worker, "_Worker__train_tables")
def test_launch_train_with_metadata_of_related_tables(
mock_train_tables,
mock_gather_existed_columns,
@@ -453,6 +501,7 @@ def test_launch_train_with_metadata_of_related_tables(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -466,15 +515,14 @@ def test_launch_train_with_metadata_of_related_tables(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_of_related_tables.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_of_related_tables.yaml",
settings={
"source": None,
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
@@ -491,21 +539,22 @@ def test_launch_train_with_metadata_of_related_tables(
"drop_null": False,
"row_limit": 800,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"train_settings": {
"source": "./path/to/fk_test.csv",
"epochs": 5,
"drop_null": True,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 600,
"batch_size": 1000,
},
@@ -513,7 +562,7 @@ def test_launch_train_with_metadata_of_related_tables(
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False,
+ "reports": [],
},
"keys": {
"fk_id": {
@@ -522,6 +571,7 @@ def test_launch_train_with_metadata_of_related_tables(
"references": {"table": "pk_test", "columns": ["Id"]},
}
},
+ "format": {}
},
},
{
@@ -532,21 +582,22 @@ def test_launch_train_with_metadata_of_related_tables(
"drop_null": False,
"row_limit": 800,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"train_settings": {
"source": "./path/to/fk_test.csv",
"epochs": 5,
"drop_null": True,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 600,
"batch_size": 1000,
},
@@ -554,7 +605,7 @@ def test_launch_train_with_metadata_of_related_tables(
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False,
+ "reports": [],
},
"keys": {
"fk_id": {
@@ -563,6 +614,7 @@ def test_launch_train_with_metadata_of_related_tables(
"references": {"table": "pk_test", "columns": ["Id"]},
}
},
+ "format": {}
},
},
True
@@ -572,6 +624,7 @@ def test_launch_train_with_metadata_of_related_tables(
assert mock_check_existence_of_key_columns.call_count == 2
assert mock_check_existence_of_referenced_columns.call_count == 2
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
["pk_test", "fk_test"],
["pk_test", "fk_test"],
@@ -581,12 +634,13 @@ def test_launch_train_with_metadata_of_related_tables(
@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
+@patch.object(Worker, "_Worker__train_tables",)
def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
mock_train_tables,
mock_gather_existed_columns,
@@ -594,6 +648,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -616,7 +671,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
@@ -633,7 +688,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"keys": {
"tdm_models_pkey": {"type": "PK", "columns": ["id"]},
@@ -644,6 +699,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
},
},
"infer_settings": {},
+ "format": {}
},
"tdm_clusters": {
"train_settings": {
@@ -652,10 +708,11 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"keys": {"tdm_clusters_pkey": {"type": "PK", "columns": ["id"]}},
"infer_settings": {},
+ "format": {}
},
},
{
@@ -666,7 +723,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"keys": {
"tdm_models_pkey": {"type": "PK", "columns": ["id"]},
@@ -677,6 +734,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
},
},
"infer_settings": {},
+ "format": {}
},
"tdm_clusters": {
"train_settings": {
@@ -685,10 +743,11 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"keys": {"tdm_clusters_pkey": {"type": "PK", "columns": ["id"]}},
"infer_settings": {},
+ "format": {}
},
"tdm_models_pk": {
"train_settings": {
@@ -697,10 +756,11 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"keys": {"tdm_models_pkey": {"type": "PK", "columns": ["id"]}},
"infer_settings": {},
+ "format": {}
},
"tdm_models_fk": {
"train_settings": {
@@ -709,7 +769,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"keys": {
"tdm_models_fkey": {
@@ -719,6 +779,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
}
},
"infer_settings": {},
+ "format": {}
},
},
True
@@ -728,6 +789,7 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
assert mock_check_existence_of_key_columns.call_count == 2
assert mock_check_existence_of_referenced_columns.call_count == 2
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
["tdm_models", "tdm_clusters"],
["tdm_clusters", "tdm_models_pk", "tdm_models_fk"],
@@ -737,12 +799,13 @@ def test_launch_train_with_metadata_of_related_tables_with_diff_keys(
@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
+@patch.object(Worker, "_Worker__train_tables")
def test_launch_train_without_metadata(
mock_train_tables,
mock_gather_existed_columns,
@@ -750,6 +813,7 @@ def test_launch_train_without_metadata(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -762,7 +826,7 @@ def test_launch_train_without_metadata(
"in case the metadata file wasn't provided and a training process was launched through CLI"
)
worker = Worker(
- table_name="test_table",
+ table_name="table",
metadata_path=None,
settings={
"source": "./path/to/source.csv",
@@ -770,65 +834,69 @@ def test_launch_train_without_metadata(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
)
worker.launch_train()
mock_train_tables.assert_called_once_with(
- ["test_table"],
- ["test_table"],
+ ["table"],
+ ["table"],
{
- "test_table": {
+ "table": {
"train_settings": {
"source": "./path/to/source.csv",
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {},
"keys": {},
+ "format": {}
}
},
{
- "test_table": {
+ "table": {
"train_settings": {
"source": "./path/to/source.csv",
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {},
"keys": {},
+ "format": {}
}
},
True
)
- mock_gather_existed_columns.assert_called_once()
- mock_check_existence_of_source.assert_called_once()
- mock_check_existence_of_key_columns.assert_called_once()
- mock_check_existence_of_referenced_columns.assert_called_once()
- mock_validate_metadata.assert_called_once()
+ mock_gather_existed_columns.assert_called_once_with("table")
+ mock_check_existence_of_source.assert_called_once_with("table")
+ mock_check_existence_of_key_columns.assert_called_once_with("table")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table")
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
- ["test_table"],
- ["test_table"],
+ ["table"],
+ ["table"],
True
)
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
+@patch.object(Worker, "_Worker__train_tables")
def test_launch_train_with_metadata_contained_global_settings(
mock_train_tables,
mock_gather_existed_columns,
@@ -836,6 +904,7 @@ def test_launch_train_with_metadata_contained_global_settings(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -850,15 +919,14 @@ def test_launch_train_with_metadata_contained_global_settings(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_with_global_settings.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_global_settings.yaml",
settings={
"source": None,
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
@@ -875,10 +943,11 @@ def test_launch_train_with_metadata_contained_global_settings(
"epochs": 5,
"drop_null": True,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
- "infer_settings": {"print_report": False},
+ "infer_settings": {"reports": []},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"keys": {
@@ -894,9 +963,10 @@ def test_launch_train_with_metadata_contained_global_settings(
"drop_null": True,
"row_limit": 500,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {},
+ "format": {}
},
},
{
@@ -907,10 +977,11 @@ def test_launch_train_with_metadata_contained_global_settings(
"epochs": 5,
"drop_null": True,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
- "infer_settings": {"print_report": False},
+ "infer_settings": {"reports": []},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"keys": {
@@ -926,9 +997,10 @@ def test_launch_train_with_metadata_contained_global_settings(
"drop_null": True,
"row_limit": 500,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
"infer_settings": {},
+ "format": {}
},
},
True
@@ -938,6 +1010,7 @@ def test_launch_train_with_metadata_contained_global_settings(
assert mock_check_existence_of_key_columns.call_count == 2
assert mock_check_existence_of_referenced_columns.call_count == 2
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
["pk_test", "fk_test"],
["pk_test", "fk_test"],
@@ -947,13 +1020,17 @@ def test_launch_train_with_metadata_contained_global_settings(
@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
-@patch.object(Worker, "_Worker__infer_tables", return_value=None)
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Worker, "_Worker__infer_tables")
def test_launch_infer_with_metadata(
mock_infer_tables,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_infer,
rp_logger,
):
@@ -972,8 +1049,7 @@ def test_launch_infer_with_metadata(
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
- "get_infer_metrics": False,
+ "reports": ["accuracy", "sample"],
"batch_size": 200,
},
log_level="INFO",
@@ -981,14 +1057,14 @@ def test_launch_infer_with_metadata(
)
worker.launch_infer()
mock_infer_tables.assert_called_once_with(
- ["test_table"],
+ ["table"],
{
- "test_table": {
+ "table": {
"train_settings": {
- "source": "./path/to/test_table.csv",
+ "source": "./path/to/table.csv",
"epochs": 100,
"drop_null": False,
- "print_report": False,
+ "reports": [],
"row_limit": 800,
"batch_size": 2000,
},
@@ -996,30 +1072,36 @@ def test_launch_infer_with_metadata(
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
- "get_infer_metrics": False,
+ "reports": ["accuracy"],
"batch_size": 200,
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
}
},
0.25,
type_of_process="infer"
)
- mock_check_existence_of_destination.assert_called_once()
- mock_validate_metadata.assert_called_once()
- mock_collect_metrics_in_infer.assert_called_once_with(["test_table"])
+ mock_check_completion_of_training.assert_called_once_with("table")
+ mock_check_existence_of_destination.assert_called_once_with("table")
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_infer.assert_called_once_with(["table"])
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
-@patch.object(Worker, "_Worker__infer_tables", return_value=None)
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Worker, "_Worker__infer_tables")
def test_launch_infer_with_metadata_of_related_tables(
mock_infer_tables,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_infer,
rp_logger,
):
@@ -1033,12 +1115,11 @@ def test_launch_infer_with_metadata_of_related_tables(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_of_related_tables.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_of_related_tables.yaml",
settings={
"size": 300,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
@@ -1059,25 +1140,26 @@ def test_launch_infer_with_metadata_of_related_tables(
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"train_settings": {
"source": "./path/to/fk_test.csv",
"epochs": 5,
"drop_null": True,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 600,
},
"infer_settings": {
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False,
+ "reports": [],
"batch_size": 200,
},
"keys": {
@@ -1087,25 +1169,32 @@ def test_launch_infer_with_metadata_of_related_tables(
"references": {"table": "pk_test", "columns": ["Id"]},
}
},
+ "format": {}
},
},
0.125,
type_of_process="infer"
)
+ assert mock_check_completion_of_training.call_count == 2
assert mock_check_existence_of_destination.call_count == 2
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_infer.assert_called_once_with(["pk_test", "fk_test"])
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
-@patch.object(Worker, "_Worker__infer_tables", return_value=None)
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Worker, "_Worker__infer_tables")
def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
mock_infer_tables,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_infer,
rp_logger,
):
@@ -1126,7 +1215,7 @@ def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
settings={
"size": 300,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
@@ -1150,10 +1239,11 @@ def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
"infer_settings": {
"size": 300,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
+ "format": {}
},
"tdm_clusters": {
"train_settings": {"source": "./path/to/tdm_clusters.csv"},
@@ -1161,10 +1251,11 @@ def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
"infer_settings": {
"size": 300,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
+ "format": {}
},
"tdm_models_pk": {
"train_settings": {"source": "./path/to/tdm_models.csv"},
@@ -1172,10 +1263,11 @@ def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
"infer_settings": {
"size": 300,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
+ "format": {}
},
"tdm_models_fk": {
"train_settings": {"source": "./path/to/tdm_models.csv"},
@@ -1189,17 +1281,20 @@ def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
"infer_settings": {
"size": 300,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
"random_seed": 1,
},
+ "format": {}
},
},
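+ # 0.08333333333333333 appears to be 0.25 / 3, i.e. the progress delta split
+ # across the three inferred tables (an assumption based on the sibling tests,
+ # which pass 0.25 for one table and 0.125 for two)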
0.08333333333333333,
type_of_process="infer"
)
+ assert mock_check_completion_of_training.call_count == 2
assert mock_check_existence_of_destination.call_count == 2
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_infer.assert_called_once_with(
["tdm_clusters", "tdm_models_pk", "tdm_models_fk"]
)
@@ -1207,13 +1302,17 @@ def test_launch_infer_with_metadata_of_related_tables_with_diff_keys(
@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
-@patch.object(Worker, "_Worker__infer_tables", return_value=None)
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Worker, "_Worker__infer_tables")
def test_launch_infer_without_metadata(
mock_infer_tables,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_infer,
rp_logger,
):
@@ -1227,13 +1326,13 @@ def test_launch_infer_without_metadata(
"the inference process was launched through CLI"
)
worker = Worker(
- table_name="test_table",
+ table_name="table",
metadata_path=None,
settings={
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
},
log_level="INFO",
@@ -1241,37 +1340,44 @@ def test_launch_infer_without_metadata(
)
worker.launch_infer()
mock_infer_tables.assert_called_once_with(
- ["test_table"],
+ ["table"],
{
- "test_table": {
+ "table": {
"train_settings": {"source": None},
"infer_settings": {
"size": 200,
"run_parallel": True,
"random_seed": 2,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 200,
},
"keys": {},
+ "format": {}
}
},
0.25,
type_of_process="infer"
)
- mock_check_existence_of_destination.assert_called_once()
- mock_validate_metadata.assert_called_once()
- mock_collect_metrics_in_infer.assert_called_once_with(["test_table"])
+ mock_check_completion_of_training.assert_called_once_with("table")
+ mock_check_existence_of_destination.assert_called_once_with("table")
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_infer.assert_called_once_with(["table"])
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_destination")
-@patch.object(Worker, "_Worker__infer_tables", return_value=None)
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Worker, "_Worker__infer_tables")
def test_launch_infer_with_metadata_contained_global_settings(
mock_infer_tables,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_infer,
rp_logger,
):
@@ -1286,13 +1392,12 @@ def test_launch_infer_with_metadata_contained_global_settings(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_with_global_settings.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_global_settings.yaml",
settings={
"size": 300,
"run_parallel": True,
"random_seed": 3,
- "print_report": True,
+ "reports": ["accuracy"],
"batch_size": 300,
},
log_level="INFO",
@@ -1305,13 +1410,14 @@ def test_launch_infer_with_metadata_contained_global_settings(
"pk_test": {
"train_settings": {"source": "./path/to/pk_test.csv", "row_limit": 800},
"infer_settings": {
- "print_report": False,
+ "reports": [],
"size": 1000,
"run_parallel": True,
"random_seed": 3,
"batch_size": 300,
},
"keys": {"pk_id": {"type": "PK", "columns": ["Id"]}},
+ "format": {}
},
"fk_test": {
"keys": {
@@ -1325,35 +1431,421 @@ def test_launch_infer_with_metadata_contained_global_settings(
"infer_settings": {
"size": 1000,
"run_parallel": True,
- "print_report": True,
+ "reports": ["accuracy"],
"random_seed": 3,
"batch_size": 300,
},
+ "format": {}
},
},
0.125,
type_of_process="infer",
)
+ assert mock_check_completion_of_training.call_count == 2
assert mock_check_existence_of_destination.call_count == 2
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_infer.assert_called_once_with(["pk_test", "fk_test"])
rp_logger.info(SUCCESSFUL_MESSAGE)


+@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
+@patch.object(Validator, "_validate_metadata")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Validator, "_check_existence_of_referenced_columns")
+@patch.object(Validator, "_check_existence_of_key_columns")
+@patch.object(Validator, "_check_existence_of_source")
+@patch.object(Validator, "_gather_existed_columns")
+@patch.object(Worker, "_infer_table")
+@patch.object(Worker, "_train_table")
+def test_train_tables_without_generation_reports(
+ mock_train_table,
+ mock_infer_table,
+ mock_gather_existed_columns,
+ mock_check_existence_of_source,
+ mock_check_existence_of_key_columns,
+ mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
+ mock_check_existence_of_destination,
+ mock_validate_metadata,
+ mock_generate_reports,
+ mock_collect_metrics_in_train,
+ rp_logger,
+):
+ """
+ Test the '__train_tables' method of the 'Worker' class
+ in case the reports won't be generated
+ """
+ rp_logger.info(
+ "Test the '__train_tables' method of the 'Worker' class "
+ "in case the reports won't be generated"
+ )
+ worker = Worker(
+ table_name=None,
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_without_reports.yaml",
+ settings={},
+ log_level="INFO",
+ type_of_process="train",
+ )
+ worker.launch_train()
+ mock_gather_existed_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_source.assert_called_once_with("test_table")
+ mock_check_existence_of_key_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("test_table")
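+ # the completion and destination checks belong to the inference path,
+ # so they must not run during training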
+ mock_check_completion_of_training.assert_not_called()
+ mock_check_existence_of_destination.assert_not_called()
+ mock_validate_metadata.assert_called_once_with("test_table")
+ mock_train_table.assert_called_once_with(
+ "test_table",
+ {
+ "test_table": {
+ "train_settings": {
+ "source": "./path/to/test_table.csv",
+ "epochs": 100,
+ "drop_null": False,
+ "reports": [],
+ "row_limit": 800,
+ "batch_size": 2000
+ },
+ "infer_settings": {
+ "destination": "./path/to/test_table_infer.csv",
+ "size": 200,
+ "run_parallel": True,
+ "random_seed": 2,
+ "reports": [],
+ "batch_size": 200
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["Id"]
+ }
+ },
+ "format": {}
+ }
+ },
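+ # 0.49 is the expected per-table progress delta during training
+ # (taken as given by the 'Worker' progress accounting, not derived here)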
+ 0.49
+ )
+ mock_infer_table.assert_not_called()
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_train.assert_called_once_with(["test_table"], ["test_table"], False)
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
-def test_init_worker_for_training_process_with_absent_metadata_and_callback_loader(
- mock_train_tables,
- mock_gather_existed_columns,
- mock_check_existence_of_source,
- mock_check_existence_of_key_columns,
- mock_check_existence_of_referenced_columns,
- mock_validate_metadata,
- rp_logger
+@patch.object(Worker, "_infer_table")
+@patch.object(Worker, "_train_table")
+def test_train_tables_with_generation_reports(
+ mock_train_table,
+ mock_infer_table,
+ mock_gather_existed_columns,
+ mock_check_existence_of_source,
+ mock_check_existence_of_key_columns,
+ mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
+ mock_check_existence_of_destination,
+ mock_validate_metadata,
+ mock_generate_reports,
+ mock_collect_metrics_in_train,
+ rp_logger,
+):
+ """
+ Test the '__train_tables' method of the 'Worker' class
+ in case the reports will be generated
+ """
+ rp_logger.info(
+ "Test the '__train_tables' method of the 'Worker' class "
+ "in case the reports will be generated"
+ )
+ worker = Worker(
+ table_name=None,
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_reports.yaml",
+ settings={},
+ log_level="INFO",
+ type_of_process="train"
+ )
+ worker.launch_train()
+ mock_gather_existed_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_source.assert_called_once_with("test_table")
+ mock_check_existence_of_key_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("test_table")
+ mock_check_completion_of_training.assert_not_called()
+ mock_check_existence_of_destination.assert_not_called()
+ mock_validate_metadata.assert_called_once_with("test_table")
+ mock_train_table.assert_called_once_with(
+ "test_table",
+ {
+ "test_table": {
+ "train_settings": {
+ "source": "./path/to/test_table.csv",
+ "epochs": 100,
+ "drop_null": False,
+ "reports": ["accuracy", "sample"],
+ "row_limit": 800,
+ "batch_size": 2000
+ },
+ "infer_settings": {
+ "destination": "./path/to/test_table_infer.csv",
+ "size": 200,
+ "run_parallel": True,
+ "random_seed": 2,
+ "reports": ["accuracy"],
+ "batch_size": 200
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["Id"]
+ }
+ },
+ "format": {}
+ }
+ },
+ 0.49
+ )
+ mock_infer_table.assert_called_once_with(
+ table="test_table",
+ metadata={
+ "test_table": {
+ "train_settings": {
+ "source": "./path/to/test_table.csv",
+ "epochs": 100,
+ "drop_null": False,
+ "reports": ["accuracy", "sample"],
+ "row_limit": 800,
+ "batch_size": 2000
+ },
+ "infer_settings": {
+ "destination": "./path/to/test_table_infer.csv",
+ "size": 200,
+ "run_parallel": True,
+ "random_seed": 2,
+ "reports": ["accuracy"],
+ "batch_size": 200
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["Id"]
+ }
+ },
+ "format": {}
+ }
+ },
+ type_of_process="train",
+ delta=0.49
+ )
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_train.assert_called_once_with(["test_table"], ["test_table"], True)
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
+@patch.object(Validator, "_validate_metadata")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Validator, "_check_existence_of_referenced_columns")
+@patch.object(Validator, "_check_existence_of_key_columns")
+@patch.object(Validator, "_check_existence_of_source")
+@patch.object(Validator, "_gather_existed_columns")
+@patch.object(Worker, "_infer_table")
+def test_infer_tables_without_generation_reports(
+ mock_infer_table,
+ mock_gather_existed_columns,
+ mock_check_existence_of_source,
+ mock_check_existence_of_key_columns,
+ mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
+ mock_check_existence_of_destination,
+ mock_validate_metadata,
+ mock_generate_reports,
+ mock_collect_metrics_in_infer,
+ rp_logger,
+):
+ """
+ Test the '__infer_tables' method of the 'Worker' class
+ in case the reports won't be generated
+ """
+ rp_logger.info(
+ "Test the '__infer_tables' method of the 'Worker' class "
+ "in case the reports won't be generated"
+ )
+ worker = Worker(
+ table_name=None,
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_without_reports.yaml",
+ settings={
+ "size": 300,
+ "run_parallel": True,
+ "random_seed": 3,
+ "reports": ["accuracy"],
+ "batch_size": 300
+ },
+ log_level="INFO",
+ type_of_process="infer"
+ )
+ worker.launch_infer()
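+ # the source-side checks belong to the training path,
+ # so they must not run during inference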
+ mock_gather_existed_columns.assert_not_called()
+ mock_check_existence_of_source.assert_not_called()
+ mock_check_existence_of_key_columns.assert_not_called()
+ mock_check_existence_of_referenced_columns.assert_not_called()
+ mock_check_completion_of_training.assert_called_once_with("test_table")
+ mock_check_existence_of_destination.assert_called_once_with("test_table")
+ mock_validate_metadata.assert_called_once_with("test_table")
+ mock_infer_table.assert_called_once_with(
+ table="test_table",
+ metadata={
+ "test_table": {
+ "train_settings": {
+ "source": "./path/to/test_table.csv",
+ "epochs": 100,
+ "drop_null": False,
+ "reports": [],
+ "row_limit": 800,
+ "batch_size": 2000
+ },
+ "infer_settings": {
+ "destination": "./path/to/test_table_infer.csv",
+ "size": 200,
+ "run_parallel": True,
+ "random_seed": 2,
+ "reports": [],
+ "batch_size": 200
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["Id"]
+ }
+ },
+ "format": {}
+ }
+ },
+ type_of_process="infer",
+ delta=0.5
+ )
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_infer.assert_called_once_with(["test_table"])
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
+@patch.object(Validator, "_validate_metadata")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Validator, "_check_existence_of_referenced_columns")
+@patch.object(Validator, "_check_existence_of_key_columns")
+@patch.object(Validator, "_check_existence_of_source")
+@patch.object(Validator, "_gather_existed_columns")
+@patch.object(Worker, "_infer_table")
+def test_infer_tables_with_generation_reports(
+ mock_infer_table,
+ mock_gather_existed_columns,
+ mock_check_existence_of_source,
+ mock_check_existence_of_key_columns,
+ mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
+ mock_check_existence_of_destination,
+ mock_validate_metadata,
+ mock_generate_reports,
+ mock_collect_metrics_in_infer,
+ rp_logger,
+):
+ """
+ Test the '__infer_tables' method of the 'Worker' class
+ in case the reports will be generated
+ """
+ rp_logger.info(
+ "Test the '__infer_tables' method of the 'Worker' class "
+ "in case the reports will be generated"
+ )
+ worker = Worker(
+ table_name=None,
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_with_reports.yaml",
+ settings={
+ "size": 300,
+ "run_parallel": True,
+ "random_seed": 3,
+ "reports": ["accuracy"],
+ "batch_size": 300,
+ },
+ log_level="INFO",
+ type_of_process="infer"
+ )
+ worker.launch_infer()
+ mock_gather_existed_columns.assert_not_called()
+ mock_check_existence_of_source.assert_not_called()
+ mock_check_existence_of_key_columns.assert_not_called()
+ mock_check_existence_of_referenced_columns.assert_not_called()
+ mock_check_completion_of_training.assert_called_once_with("test_table")
+ mock_check_existence_of_destination.assert_called_once_with("test_table")
+ mock_validate_metadata.assert_called_once_with("test_table")
+ mock_infer_table.assert_called_once_with(
+ table="test_table",
+ metadata={
+ "test_table": {
+ "train_settings": {
+ "source": "./path/to/test_table.csv",
+ "epochs": 100,
+ "drop_null": False,
+ "reports": ["accuracy", "sample"],
+ "row_limit": 800,
+ "batch_size": 2000
+ },
+ "infer_settings": {
+ "destination": "./path/to/test_table_infer.csv",
+ "size": 200,
+ "run_parallel": True,
+ "random_seed": 2,
+ "reports": ["accuracy"],
+ "batch_size": 200
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["Id"]
+ }
+ },
+ "format": {}
+ }
+ },
+ type_of_process="infer",
+ delta=0.25
+ )
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_infer.assert_called_once_with(["test_table"])
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
+@patch.object(Validator, "_validate_metadata")
+@patch.object(Validator, "_check_existence_of_referenced_columns")
+@patch.object(Validator, "_check_existence_of_key_columns")
+@patch.object(Validator, "_check_existence_of_source")
+@patch.object(Validator, "_gather_existed_columns")
+@patch.object(Worker, "_Worker__train_tables")
+def test_launch_train_with_absent_metadata_and_callback_loader(
+ mock_train_tables,
+ mock_gather_existed_columns,
+ mock_check_existence_of_source,
+ mock_check_existence_of_key_columns,
+ mock_check_existence_of_referenced_columns,
+ mock_validate_metadata,
+ mock_generate_reports,
+ mock_collect_metrics_in_train,
+ rp_logger
):
"""
Test the initialization of 'Worker' class
@@ -1365,63 +1857,66 @@ def test_init_worker_for_training_process_with_absent_metadata_and_callback_load
"with the absent metadata and provided callback function during the training process"
)
worker = Worker(
- table_name="test_table",
+ table_name="table",
metadata_path=None,
settings={
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
loader=MagicMock()
)
assert worker.metadata == {
- "test_table": {
+ "table": {
"train_settings": {
"source": None,
"batch_size": 1000,
"drop_null": True,
"epochs": 20,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 1000,
},
"infer_settings": {},
"keys": {},
+ "format": {}
}
}
worker.launch_train()
mock_train_tables.assert_called_with(
- ["test_table"],
- ["test_table"],
+ ["table"],
+ ["table"],
{
- "test_table": {
+ "table": {
"train_settings": {
"source": None,
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {},
- "keys": {}
+ "keys": {},
+ "format": {}
}
},
{
- "test_table": {
+ "table": {
"train_settings": {
"source": None,
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {},
- "keys": {}
+ "keys": {},
+ "format": {}
}
},
True
@@ -1431,23 +1926,30 @@ def test_init_worker_for_training_process_with_absent_metadata_and_callback_load
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
mock_validate_metadata.assert_called_once()
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_train.assert_called_once_with(
+ ["table"],
+ ["table"],
+ True
+ )
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
-def test_launch_train_with_metadata_without_source_paths(
+@patch.object(Worker, "_Worker__train_tables")
+def test_launch_train_with_metadata_without_source_paths_and_loader(
mock_train_tables,
mock_gather_existed_columns,
mock_check_existence_of_source,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -1471,7 +1973,7 @@ def test_launch_train_with_metadata_without_source_paths(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
@@ -1488,25 +1990,26 @@ def test_launch_train_with_metadata_without_source_paths(
"drop_null": False,
"row_limit": 800,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True
+ "reports": ["accuracy"]
},
"keys": {
"pk_id": {
"type": "PK",
"columns": ["Id"]
}
- }
+ },
+ "format": {}
},
"fk_test": {
"train_settings": {
"epochs": 5,
"drop_null": True,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 600,
"batch_size": 1000
},
@@ -1514,7 +2017,7 @@ def test_launch_train_with_metadata_without_source_paths(
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False
+ "reports": []
},
"keys": {
"fk_id": {
@@ -1525,7 +2028,8 @@ def test_launch_train_with_metadata_without_source_paths(
"columns": ["Id"]
}
}
- }
+ },
+ "format": {}
}
},
{
@@ -1535,25 +2039,26 @@ def test_launch_train_with_metadata_without_source_paths(
"drop_null": False,
"row_limit": 800,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True
+ "reports": ["accuracy"]
},
"keys": {
"pk_id": {
"type": "PK",
"columns": ["Id"]
}
- }
+ },
+ "format": {}
},
"fk_test": {
"train_settings": {
"epochs": 5,
"drop_null": True,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
"row_limit": 600,
"batch_size": 1000
},
@@ -1561,7 +2066,7 @@ def test_launch_train_with_metadata_without_source_paths(
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False
+ "reports": []
},
"keys": {
"fk_id": {
@@ -1572,38 +2077,42 @@ def test_launch_train_with_metadata_without_source_paths(
"columns": ["Id"]
}
}
- }
+ },
+ "format": {}
}
},
True
)
mock_gather_existed_columns.assert_not_called()
+ mock_check_existence_of_source.assert_not_called()
+ mock_check_existence_of_key_columns.assert_not_called()
+ mock_check_existence_of_referenced_columns.assert_not_called()
+ assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
["pk_test", "fk_test"],
["pk_test", "fk_test"],
True
)
- mock_check_existence_of_source.assert_not_called()
- mock_check_existence_of_key_columns.assert_not_called()
- mock_check_existence_of_referenced_columns.assert_not_called()
- assert mock_validate_metadata.call_count == 2
rp_logger.info(SUCCESSFUL_MESSAGE)


@patch.object(Worker, "_collect_metrics_in_train")
+@patch.object(Worker, "_generate_reports")
@patch.object(Validator, "_validate_metadata")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-@patch.object(Worker, "_Worker__train_tables", return_value=None)
-def test_launch_train_with_metadata_without_train_settings(
+@patch.object(Worker, "_Worker__train_tables")
+def test_launch_train_with_metadata_without_train_settings_and_loader(
mock_train_tables,
mock_gather_existed_columns,
mock_check_existence_of_source,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_metadata,
+ mock_generate_reports,
mock_collect_metrics_in_train,
rp_logger,
):
@@ -1621,14 +2130,13 @@ def test_launch_train_with_metadata_without_train_settings(
)
worker = Worker(
table_name=None,
- metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/"
- "metadata_without_train_settings.yaml",
+ metadata_path=f"{DIR_NAME}/unit/test_worker/fixtures/metadata_without_train_settings.yaml",
settings={
"epochs": 20,
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True,
+ "reports": ["accuracy", "sample"],
},
log_level="INFO",
type_of_process="train",
@@ -1645,19 +2153,20 @@ def test_launch_train_with_metadata_without_train_settings(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True
+ "reports": ["accuracy"]
},
"keys": {
"pk_id": {
"type": "PK",
"columns": ["Id"]
}
- }
+ },
+ "format": {}
},
"fk_test": {
"train_settings": {
@@ -1665,13 +2174,13 @@ def test_launch_train_with_metadata_without_train_settings(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False
+ "reports": []
},
"keys": {
"fk_id": {
@@ -1682,7 +2191,8 @@ def test_launch_train_with_metadata_without_train_settings(
"columns": ["Id"]
}
}
- }
+ },
+ "format": {}
}
},
{
@@ -1692,19 +2202,20 @@ def test_launch_train_with_metadata_without_train_settings(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"size": 200,
"run_parallel": True,
- "print_report": True
+ "reports": ["accuracy"]
},
"keys": {
"pk_id": {
"type": "PK",
"columns": ["Id"]
}
- }
+ },
+ "format": {}
},
"fk_test": {
"train_settings": {
@@ -1712,13 +2223,13 @@ def test_launch_train_with_metadata_without_train_settings(
"drop_null": True,
"row_limit": 1000,
"batch_size": 1000,
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"size": 90,
"run_parallel": True,
"random_seed": 2,
- "print_report": False
+ "reports": []
},
"keys": {
"fk_id": {
@@ -1729,7 +2240,8 @@ def test_launch_train_with_metadata_without_train_settings(
"columns": ["Id"]
}
}
- }
+ },
+ "format": {}
}
},
True
@@ -1739,9 +2251,177 @@ def test_launch_train_with_metadata_without_train_settings(
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
assert mock_validate_metadata.call_count == 2
+ mock_generate_reports.assert_called_once()
mock_collect_metrics_in_train.assert_called_once_with(
["pk_test", "fk_test"],
["pk_test", "fk_test"],
True
)
rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_collect_metrics_in_infer")
+@patch.object(Worker, "_generate_reports")
+@patch.object(Worker, "_infer_table")
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_validate_metadata")
+def test_launch_infer_of_pretrained_table(
+ mock_validate_metadata,
+ mock_check_existence_of_destination,
+ mock_check_completion_of_training,
+ mock_infer_table,
+ mock_generate_reports,
+ mock_collect_metrics_in_infer,
+ rp_logger,
+):
+ """
+ Test that the inference process has been launched
+ if the training process of the table has been finished
+ """
+ rp_logger.info(
+ "Test that the inference process has been launched "
+ "if the training process of the table has been finished"
+ )
+ worker = Worker(
+ table_name="table",
+ metadata_path=None,
+ settings={
+ "size": 300,
+ "run_parallel": True,
+ "random_seed": 3,
+ "reports": ["accuracy"],
+ "batch_size": 300,
+ },
+ log_level="INFO",
+ type_of_process="infer",
+ loader=None
+ )
+ metadata = {
+ "table": {
+ "train_settings": {
+ "source": None
+ },
+ "infer_settings": {
+ "size": 300,
+ "run_parallel": True,
+ "random_seed": 3,
+ "reports": ["accuracy"],
+ "batch_size": 300
+ },
+ "keys": {},
+ "format": {}
+ }
+ }
+ worker.launch_infer()
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_check_existence_of_destination.assert_called_once_with("table")
+ mock_check_completion_of_training.assert_called_once_with("table")
+ mock_infer_table.assert_called_once_with(
+ table="table",
+ metadata=metadata,
+ type_of_process="infer",
+ delta=0.25
+ )
+ mock_generate_reports.assert_called_once()
+ mock_collect_metrics_in_infer.assert_called_once_with(["table"])
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_infer_table")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_validate_metadata")
+def test_launch_infer_of_not_pretrained_table_and_absent_success_file(
+ mock_validate_metadata,
+ mock_check_existence_of_destination,
+ mock_infer_table,
+ caplog,
+ rp_logger,
+):
+ """
+ Test that the inference process hasn't been started
+ in case the training process of the table hasn't been finished,
+ and the appropriate success file 'message.success' is absent
+ """
+ rp_logger.info(
+ "Test that the inference process hasn't been started "
+ "in case the training process of the table hasn't been finished, "
+ "and the appropriate success file 'message.success' is absent"
+ )
+ with pytest.raises(ValidationError) as error:
+ with caplog.at_level("ERROR"):
+ worker = Worker(
+ table_name="table",
+ metadata_path=None,
+ settings={
+ "size": 300,
+ "run_parallel": True,
+ "random_seed": 3,
+ "reports": ["accuracy"],
+ "batch_size": 300,
+ },
+ log_level="INFO",
+ type_of_process="infer",
+ loader=None
+ )
+ worker.launch_infer()
+ message = (
+ "The training of the table - 'table' hasn't been completed. "
+ "Please, retrain the table."
+ )
+ assert message in str(error.value)
+ assert message in caplog.text
+ mock_check_existence_of_destination.assert_called_once_with("table")
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_infer_table.assert_not_called()
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Worker, "_infer_table")
+@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_validate_metadata")
+def test_launch_infer_of_not_pretrained_table_and_success_file_with_wrong_content(
+ mock_validate_metadata,
+ mock_check_existence_of_destination,
+ mock_infer_table,
+ test_success_file,
+ caplog,
+ rp_logger,
+):
+ """
+ Test that the inference process hasn't been started
+ in case the training process of the table hasn't been finished,
+ and the appropriate success file 'message.success' is present,
+ but the content of the file doesn't correspond to a finished training process
+ """
+ rp_logger.info(
+ "Test that the inference process hasn't been started "
+ "in case the training process of the table hasn't been finished, "
+ "and the appropriate success file 'message.success' is present, "
+ "but the content of the file doesn't correspond to finished training process"
+ )
+ with pytest.raises(ValidationError) as error:
+ with caplog.at_level("ERROR"):
+ worker = Worker(
+ table_name="table",
+ metadata_path=None,
+ settings={
+ "size": 300,
+ "run_parallel": True,
+ "random_seed": 3,
+ "reports": ["accuracy"],
+ "batch_size": 300,
+ },
+ log_level="INFO",
+ type_of_process="infer",
+ loader=None
+ )
+ worker.launch_infer()
+ message = (
+ "The training of the table - 'table' hasn't been completed. "
+ "Please, retrain the table."
+ )
+ assert message in str(error.value)
+ assert message in caplog.text
+ mock_check_existence_of_destination.assert_called_once_with("table")
+ mock_validate_metadata.assert_called_once_with("table")
+ mock_infer_table.assert_not_called()
+ rp_logger.info(SUCCESSFUL_MESSAGE)
diff --git a/src/tests/unit/validation_metadata/test_validation_metadata.py b/src/tests/unit/validation_metadata/test_validation_metadata.py
index 60c2221d..c54139e9 100644
--- a/src/tests/unit/validation_metadata/test_validation_metadata.py
+++ b/src/tests/unit/validation_metadata/test_validation_metadata.py
@@ -1,4 +1,4 @@
-from unittest.mock import patch
+from unittest.mock import patch, call
import pytest
from marshmallow import ValidationError
@@ -8,7 +8,7 @@
FAKE_METADATA_PATH = "path/to/metadata.yaml"
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -20,48 +20,49 @@ def test_validate_metadata_of_one_table_without_fk_key_in_train_process(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
rp_logger
):
"""
Test the validation of the metadata of one table contained only the primary key
- used in the training process
+ during the training process
"""
rp_logger.info(
"Test the validation of the metadata of one table contained the primary key "
- "used in the training process"
+ "during the training process"
)
test_metadata = {
- "test_table": {
- "train_settings": {
- "source": "path/to/test_table.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/test_table.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
- mock_gather_existed_columns.assert_called_once()
- mock_check_existence_of_source.assert_called_once()
- mock_check_existence_of_key_columns.assert_called_once()
- mock_check_existence_of_referenced_columns.assert_called_once()
+ mock_gather_existed_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_source.assert_called_once_with("test_table")
+ mock_check_existence_of_key_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("test_table")
mock_validate_referential_integrity.assert_not_called()
- mock_check_existence_of_success_file.assert_not_called()
+ mock_check_completion_of_training.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -73,48 +74,47 @@ def test_validate_metadata_of_one_table_without_fk_key_in_train_process_without_
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
rp_logger
):
"""
- Test the validation of the metadata of one table
- contained only the primary key
- used in the training process
- with 'validation_source' set to 'False'
+ Test the validation of the metadata of one table containing only the primary key
+ during the training process with 'validation_source' set to 'False'
"""
rp_logger.info(
"Test the validation of the metadata of one table "
- "contained only the primary ke used in the training process "
+ "contained only the primary key during the training process "
"with 'validation_source' set to 'False'"
)
test_metadata = {
- "table_a": {
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "table_a": {
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
- },
- "table_b": {
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
- }
+ }
+ },
+ "table_b": {
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH,
validation_source=False
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
@@ -123,7 +123,7 @@ def test_validate_metadata_of_one_table_without_fk_key_in_train_process_without_
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
mock_validate_referential_integrity.assert_not_called()
- mock_check_existence_of_success_file.assert_not_called()
+ mock_check_completion_of_training.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
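
A reading aid for the decorator stacks throughout this file: `@patch.object` decorators apply bottom-up, so the bottom-most patch supplies the first mock parameter, and the renamed top decorator corresponds to the last mock parameter in each test signature. A standalone illustration of the mechanism:

```python
from unittest.mock import patch

class Service:
    def first(self):
        return "first"

    def second(self):
        return "second"

@patch.object(Service, "second")  # applied last  -> second mock parameter
@patch.object(Service, "first")   # applied first -> first mock parameter
def test_decorator_order(mock_first, mock_second):
    Service().first()
    mock_first.assert_called_once_with()
    mock_second.assert_not_called()
```
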
@@ -135,24 +135,24 @@ def test_check_key_column_in_pk(rp_logger):
"Test that the column of the primary key exists in the source table"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/"
- "csv_tables/table_with_data.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "table_a": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/table_with_data.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
@@ -201,6 +201,7 @@ def test_check_key_column_in_fk(rp_logger):
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_id": {
@@ -212,7 +213,7 @@ def test_check_key_column_in_fk(rp_logger):
rp_logger.info(SUCCESSFUL_MESSAGE)
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -224,60 +225,61 @@ def test_validate_metadata_of_related_tables_with_fk_key_in_train_process(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
rp_logger
):
"""
Test the validation of the metadata of related tables
contained the primary key and the foreign key
- used in the training process
+ during the training process
"""
rp_logger.info(
"Test the validation of the metadata of related tables "
- "contained only the primary key and the foreign key used in the training process"
+ "contained only the primary key and the foreign key during the training process"
)
test_metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "fk_id": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
- }
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv"
},
- "table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv"
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
+ "fk_id": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
"columns": ["id"]
}
}
+ }
+ },
+ "table_a": {
+ "train_settings": {
+ "source": "path/to/table_a.csv"
},
- }
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ }
+ }
+ },
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_id": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == test_metadata
@@ -285,12 +287,33 @@ def test_validate_metadata_of_related_tables_with_fk_key_in_train_process(
assert mock_check_existence_of_source.call_count == 2
assert mock_check_existence_of_key_columns.call_count == 2
assert mock_check_existence_of_referenced_columns.call_count == 2
- assert mock_validate_referential_integrity.call_count == 1
- mock_check_existence_of_success_file.assert_not_called()
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_id",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ }
+ }
+ }
+ )
+ mock_check_completion_of_training.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -302,56 +325,57 @@ def test_validate_metadata_of_related_tables_with_fk_key_in_train_process_withou
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
rp_logger
):
"""
Test the validation of the metadata of related tables
contained the primary key and the foreign key
- used in the training process with 'validation_source' set to 'False'
+ during the training process with 'validation_source' set to 'False'
"""
rp_logger.info(
"Test the validation of the metadata of related tables "
"contained the primary key and the foreign key "
- "used in the training process with 'validation_source' set to 'False'"
+ "during the training process with 'validation_source' set to 'False'"
)
test_metadata = {
- "table_a": {
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "table_a": {
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
- },
- "table_b": {
- "keys": {
- "pk_id": {
- "type": "PK",
+ }
+ },
+ "table_b": {
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "fk_id": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
"columns": ["id"]
- },
- "fk_id": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
}
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH,
validation_source=False
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_id": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == test_metadata
@@ -359,12 +383,30 @@ def test_validate_metadata_of_related_tables_with_fk_key_in_train_process_withou
mock_check_existence_of_source.assert_not_called()
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
- assert mock_validate_referential_integrity.call_count == 1
- mock_check_existence_of_success_file.assert_not_called()
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_id",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ }
+ }
+ }
+ )
+ mock_check_completion_of_training.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -376,72 +418,73 @@ def test_validate_metadata_of_related_tables_with_several_fk_key_in_train_proces
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
rp_logger
):
"""
Test the validation of the metadata of related tables
contained several foreign keys
- used in the training process
+ during the training process
"""
rp_logger.info(
"Test the validation of the metadata of related tables contained several foreign keys "
- "used in the training process"
+ "during the training process"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv"
+ "table_a": {
+ "train_settings": {
+ "source": "path/to/table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
- }
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
}
+ }
+ },
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv"
},
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv"
+ "keys": {
+ "fk_1": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
},
- "keys": {
- "fk_1": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
- },
- "fk_2": {
- "type": "FK",
- "columns": ["name"],
- "references": {
- "table": "table_a",
- "columns": ["name"]
- }
+ "fk_2": {
+ "type": "FK",
+ "columns": ["name"],
+ "references": {
+ "table": "table_a",
+ "columns": ["name"]
}
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_1": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
},
"fk_2": {
- "parent_columns": ["name"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["name"]
}
}
assert validator.merged_metadata == test_metadata
@@ -450,11 +493,11 @@ def test_validate_metadata_of_related_tables_with_several_fk_key_in_train_proces
assert mock_check_existence_of_key_columns.call_count == 2
assert mock_check_existence_of_referenced_columns.call_count == 2
assert mock_validate_referential_integrity.call_count == 2
- mock_check_existence_of_success_file.assert_not_called()
+ mock_check_completion_of_training.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -466,68 +509,69 @@ def test_validate_metadata_of_related_tables_with_several_fk_key_in_train_withou
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
rp_logger
):
"""
Test the validation of the metadata of related tables
contained several foreign keys
- used in the training process
+ during the training process
with 'validation_source' set to 'False'
"""
rp_logger.info(
"Test the validation of the metadata of related tables contained several foreign keys "
- "used in the training process with 'validation_source' set to 'False"
+ "during the training process with 'validation_source' set to 'False"
)
test_metadata = {
- "table_a": {
- "keys": {
- "pk_id": {
- "type": "PK",
+ "table_a": {
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ }
+ },
+ "table_b": {
+ "keys": {
+ "fk_1": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
"columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
}
- }
- },
- "table_b": {
- "keys": {
- "fk_1": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
- },
- "fk_2": {
- "type": "FK",
- "columns": ["name"],
- "references": {
- "table": "table_a",
- "columns": ["name"]
- }
+ },
+ "fk_2": {
+ "type": "FK",
+ "columns": ["name"],
+ "references": {
+ "table": "table_a",
+ "columns": ["name"]
}
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH,
validation_source=False
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_1": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
},
"fk_2": {
- "parent_columns": ["name"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["name"]
}
}
assert validator.merged_metadata == test_metadata
@@ -536,13 +580,13 @@ def test_validate_metadata_of_related_tables_with_several_fk_key_in_train_withou
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
assert mock_validate_referential_integrity.call_count == 2
- mock_check_existence_of_success_file.assert_not_called()
+ mock_check_completion_of_training.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -554,37 +598,38 @@ def test_validate_metadata_of_one_table_without_fk_key_in_infer_process(
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
- mock_check_existence_of_success_file,
mock_check_existence_of_generated_data,
rp_logger
):
"""
Test the validation of the metadata of one table
- contained the primary key used in the inference process
+ contained the primary key during the inference process
"""
rp_logger.info(
"Test the validation of the metadata of one table contained the primary key "
- "used in the inference process"
+ "during the inference process"
)
test_metadata = {
- "test_table": {
- "train_settings": {
- "source": "path/to/test_table.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/test_table.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="infer",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
@@ -593,16 +638,16 @@ def test_validate_metadata_of_one_table_without_fk_key_in_infer_process(
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
mock_validate_referential_integrity.assert_not_called()
- mock_check_existence_of_destination.assert_called_once()
- mock_check_existence_of_success_file.assert_not_called()
+ mock_check_completion_of_training.assert_called_once_with("test_table")
+ mock_check_existence_of_destination.assert_called_once_with("test_table")
mock_check_existence_of_generated_data.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_gather_existed_columns")
@@ -610,74 +655,75 @@ def test_validate_metadata_of_related_tables_without_fk_key_in_infer_process(
mock_gather_existed_columns,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
mock_check_existence_of_generated_data,
rp_logger
):
"""
Test the validation of the metadata of related tables
contained only the primary key and the unique key
- used in the inference process
+ during the inference process
"""
rp_logger.info(
"Test the validation of the metadata of related tables "
- "contained only the primary key and the unique key used in the inference process"
+ "contained only the primary key and the unique key during the inference process"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv"
- },
- "infer_settings": {
- "destination": "path/to/generated_table_a.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "table_a": {
+ "train_settings": {
+ "source": "path/to/table_a.csv"
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
+ }
+ },
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv"
},
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv"
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
- }
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="infer",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
mock_gather_existed_columns.assert_not_called()
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
+ assert mock_check_completion_of_training.call_count == 2
assert mock_check_existence_of_destination.call_count == 2
mock_validate_referential_integrity.assert_not_called()
- mock_check_existence_of_success_file.assert_not_called()
mock_check_existence_of_generated_data.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_gather_existed_columns")
@@ -685,83 +731,108 @@ def test_validate_metadata_of_related_tables_with_fk_key_in_infer_process(
mock_gather_existed_columns,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
mock_check_existence_of_generated_data,
rp_logger
):
"""
Test the validation of the metadata of related tables
contained the primary key and the foreign key
- used in the inference process
+ during the inference process
"""
rp_logger.info(
"Test the validation of the metadata of related tables "
- "contained only the primary key and the foreign key used in the inference process"
+ "contained only the primary key and the foreign key during the inference process"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv"
- },
- "infer_settings": {
- "destination": "path/to/generated_table_a.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
- }
+ "table_a": {
+ "train_settings": {
+ "source": "path/to/table_a.csv"
},
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "fk_id": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
- }
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
}
- }
- validator = Validator(
- metadata=test_metadata,
- type_of_process="infer",
- metadata_path=FAKE_METADATA_PATH
+ },
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "fk_id": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ }
+ }
+ }
+ }
+ validator = Validator(
+ metadata=test_metadata,
+ type_of_process="infer",
+ metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_id": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == test_metadata
mock_gather_existed_columns.assert_not_called()
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
+ assert mock_check_completion_of_training.call_count == 2
assert mock_check_existence_of_destination.call_count == 2
- assert mock_validate_referential_integrity.call_count == 1
- mock_check_existence_of_success_file.assert_not_called()
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_id",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv"
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ }
+ }
+ }
+ )
mock_check_existence_of_generated_data.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_gather_existed_columns")
@@ -769,105 +840,105 @@ def test_validate_metadata_of_related_tables_with_several_fk_key_in_infer_proces
mock_gather_existed_columns,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
mock_check_existence_of_generated_data,
rp_logger
):
"""
Test the validation of the metadata of related tables
- contained several foreign keys
- used in the inference process
+ contained several foreign keys during the inference process
"""
rp_logger.info(
"Test the validation of the metadata of related tables contained several foreign keys "
- "used in the inference process"
+ "during the inference process"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv"
- },
- "infer_settings": {
- "destination": "path/to/generated_table_a.csv"
+ "table_a": {
+ "train_settings": {
+ "source": "path/to/table_a.csv"
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
- }
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
}
+ }
+ },
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv"
},
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv"
+ "keys": {
+ "fk_1": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
},
- "keys": {
- "fk_1": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
- },
- "fk_2": {
- "type": "FK",
- "columns": ["name"],
- "references": {
- "table": "table_a",
- "columns": ["name"]
- }
+ "fk_2": {
+ "type": "FK",
+ "columns": ["name"],
+ "references": {
+ "table": "table_a",
+ "columns": ["name"]
}
}
}
}
+ }
validator = Validator(
metadata=test_metadata,
type_of_process="infer",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_1": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
},
"fk_2": {
- "parent_columns": ["name"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["name"]
}
}
assert validator.merged_metadata == test_metadata
mock_gather_existed_columns.assert_not_called()
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
+ assert mock_check_completion_of_training.call_count == 2
assert mock_check_existence_of_destination.call_count == 2
assert mock_validate_referential_integrity.call_count == 2
- mock_check_existence_of_success_file.assert_not_called()
mock_check_existence_of_generated_data.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-def test_validate_incomplete_metadata_contained_fk_key_in_train_process_without_print_report(
+def test_validate_incomplete_metadata_contained_fk_key_in_train_process_without_reports(
mock_gather_existed_columns,
mock_check_existence_of_source,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
mock_check_existence_of_generated_data,
test_metadata_storage,
rp_logger
@@ -875,67 +946,69 @@ def test_validate_incomplete_metadata_contained_fk_key_in_train_process_without_
"""
Test the validation of the incomplete metadata of one table
contained the foreign key but not contained the information of the parent table.
- It's used in the training process with the parameter 'print_report' set to False
+ It's used in the training process without the generation of reports
"""
rp_logger.info(
"Test the validation of the incomplete metadata of one table "
"contained the foreign key but not contained the information of the parent table. "
- "It used in the training process with the parameter 'print_report' set to False"
+ "It used in the training process without the generation of reports"
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- "print_report": False
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ "reports": []
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
+ }
validator = Validator(
metadata=metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == {
"table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv",
- "print_report": True
- },
- "infer_settings": {
- "destination": "path/to/generated_table_a.csv"
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
- }
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
}
},
+ "format": {}
+ },
"table_b": {
"train_settings": {
"source": "path/to/table_b.csv",
- "print_report": False
+ "reports": []
},
"keys": {
"fk_key": {
@@ -950,66 +1023,106 @@ def test_validate_incomplete_metadata_contained_fk_key_in_train_process_without_
}
}
assert mock_gather_existed_columns.call_count == 2
- assert mock_check_existence_of_source.call_count == 2
- assert mock_check_existence_of_key_columns.call_count == 2
- assert mock_check_existence_of_referenced_columns.call_count == 2
- mock_validate_referential_integrity.assert_called_once()
- mock_check_existence_of_success_file.assert_called_once()
+ mock_check_existence_of_source.assert_called_once_with("table_b")
+ mock_check_existence_of_key_columns.assert_called_once_with("table_b")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table_b")
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_key",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ }
+ )
+ mock_check_completion_of_training.assert_called_once_with("table_a")
mock_check_existence_of_generated_data.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
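
The next test is parametrized over every 'reports' combination that implies generating synthetic data; `pytest.mark.parametrize` re-runs the test once per listed value and injects it through the extra `value` argument. A self-contained example of the mechanism:

```python
import pytest

@pytest.mark.parametrize("reports", [
    ["accuracy", "sample"],
    ["accuracy", "metrics_only"],
    ["accuracy"],
    ["metrics_only"],
])
def test_runs_once_per_value(reports):
    # Each listed combination includes at least one report type
    # that requires generated synthetic data.
    assert {"accuracy", "metrics_only"} & set(reports)
```
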
+@pytest.mark.parametrize("value", [
+ ["accuracy", "sample"],
+ ["accuracy", "metrics_only"],
+ ["accuracy"],
+ ["metrics_only"]
+])
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-def test_validate_incomplete_metadata_contained_fk_key_in_train_process_with_print_report(
+def test_validate_incomplete_metadata_contained_fk_key_in_train_process_with_gen_data_and_reports(
mock_gather_existed_columns,
mock_check_existence_of_source,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
mock_check_existence_of_generated_data,
test_metadata_storage,
+ value,
rp_logger
):
"""
Test the validation of the incomplete metadata of one table
contained the foreign key but not contained the information of the parent table.
- It's used in the training process with the parameter 'print_report' set to True
+ It's used in the training process with the generation of reports
+ that require the generation of the synthetic data
"""
rp_logger.info(
"Test the validation of the incomplete metadata of one table "
"contained the foreign key but not contained the information of the parent table. "
- "It used in the training process with the parameter 'print_report' set to True"
+ "It's used in the training process with the generation reports "
+ "that requires the generation of the synthetic data "
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- "print_report": True
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ "reports": value
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
+ }
validator = Validator(
metadata=metadata,
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_key": {
@@ -1019,28 +1132,168 @@ def test_validate_incomplete_metadata_contained_fk_key_in_train_process_with_pri
}
assert validator.merged_metadata == {
"table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv",
- "print_report": True
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "infer_settings": {
- "destination": "path/to/generated_table_a.csv"
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ },
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ "reports": value
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ }
+ }
+ }
+ }
+ assert mock_gather_existed_columns.call_count == 2
+ mock_check_existence_of_source.assert_called_once_with("table_b")
+ mock_check_existence_of_key_columns.assert_called_once_with("table_b")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table_b")
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_key",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ }
+ )
+ mock_check_completion_of_training.assert_called_once_with("table_a")
+ mock_check_existence_of_generated_data.assert_called_once_with("table_a")
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
+@patch.object(Validator, "_check_existence_of_generated_data")
+@patch.object(Validator, "_check_completion_of_training")
+@patch.object(Validator, "_validate_referential_integrity")
+@patch.object(Validator, "_check_existence_of_referenced_columns")
+@patch.object(Validator, "_check_existence_of_key_columns")
+@patch.object(Validator, "_check_existence_of_source")
+@patch.object(Validator, "_gather_existed_columns")
+def test_validate_incomplete_metadata_contained_fk_key_in_train_process_with_gen_sample_report(
+ mock_gather_existed_columns,
+ mock_check_existence_of_source,
+ mock_check_existence_of_key_columns,
+ mock_check_existence_of_referenced_columns,
+ mock_validate_referential_integrity,
+ mock_check_completion_of_training,
+ mock_check_existence_of_generated_data,
+ test_metadata_storage,
+ rp_logger
+):
+ """
+ Test the validation of the incomplete metadata of one table
+ contained the foreign key but not contained the information of the parent table.
+ It's used in the training process with the generation of only a 'sample' report
+ that doesn't require the generation of the synthetic data
+ """
+ rp_logger.info(
+ "Test the validation of the incomplete metadata of one table "
+ "contained the foreign key but not contained the information of the parent table. "
+ "It's used in the training process with the generation only a 'sample' report "
+ "that doesn't require the generation of the synthetic data"
+ )
+ metadata = {
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ "reports": ["sample"]
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
"columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
}
}
+ }
+ }
+ }
+ validator = Validator(
+ metadata=metadata,
+ type_of_process="train",
+ metadata_path=FAKE_METADATA_PATH
+ )
+ validator.errors = dict()
+ validator.run()
+ assert validator.mapping == {
+ "fk_key": {
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
+ }
+ }
+ assert validator.merged_metadata == {
+ "table_a": {
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
},
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ },
"table_b": {
"train_settings": {
"source": "path/to/table_b.csv",
- "print_report": True
+ "reports": ["sample"]
},
"keys": {
"fk_key": {
@@ -1055,17 +1308,47 @@ def test_validate_incomplete_metadata_contained_fk_key_in_train_process_with_pri
}
}
assert mock_gather_existed_columns.call_count == 2
- assert mock_check_existence_of_source.call_count == 2
- assert mock_check_existence_of_key_columns.call_count == 2
- assert mock_check_existence_of_referenced_columns.call_count == 2
- mock_validate_referential_integrity.assert_called_once()
- mock_check_existence_of_success_file.assert_called_once()
- mock_check_existence_of_generated_data.assert_called_once()
+ mock_check_existence_of_source.assert_called_once_with("table_b")
+ mock_check_existence_of_key_columns.assert_called_once_with("table_b")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table_b")
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_key",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ }
+ )
+ mock_check_completion_of_training.assert_called_once_with("table_a")
+ mock_check_existence_of_generated_data.assert_not_called()
rp_logger.info(SUCCESSFUL_MESSAGE)
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_destination")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@@ -1079,7 +1362,7 @@ def test_validate_incomplete_metadata_in_infer_process(
mock_check_existence_of_referenced_columns,
mock_check_existence_of_destination,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
mock_check_existence_of_generated_data,
test_metadata_storage,
rp_logger
@@ -1095,54 +1378,56 @@ def test_validate_incomplete_metadata_in_infer_process(
"It used in the inference process"
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
+ }
validator = Validator(
metadata=metadata,
type_of_process="infer",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == {
"table_a": {
- "train_settings": {
- "source": "path/to/table_a.csv",
- "print_report": True
- },
- "infer_settings": {
- "destination": "path/to/generated_table_a.csv"
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
},
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- },
- "uq_id": {
- "type": "UQ",
- "columns": ["name"]
- }
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
}
},
+ "format": {}
+ },
"table_b": {
"train_settings": {
"source": "path/to/table_b.csv"
@@ -1163,10 +1448,40 @@ def test_validate_incomplete_metadata_in_infer_process(
mock_check_existence_of_source.assert_not_called()
mock_check_existence_of_key_columns.assert_not_called()
mock_check_existence_of_referenced_columns.assert_not_called()
- assert mock_check_existence_of_destination.call_count == 2
- mock_validate_referential_integrity.assert_called_once()
- mock_check_existence_of_success_file.assert_called_once()
- mock_check_existence_of_generated_data.assert_called_once()
+ mock_check_existence_of_destination.assert_called_once_with("table_b")
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_key",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ }
+ )
+ assert mock_check_completion_of_training.call_args_list == [call("table_b"), call("table_a")]
+ mock_check_existence_of_generated_data.assert_called_once_with("table_a")
rp_logger.info(SUCCESSFUL_MESSAGE)
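
The ordered assertion above is why `call` is now imported at the top of the file: comparing `call_args_list` against a list of `unittest.mock.call` objects checks both the arguments and the order of invocations, which `assert_any_call` would not. A self-contained example:

```python
from unittest.mock import MagicMock, call

check = MagicMock()
check("table_b")
check("table_a")

# Passes only if both calls happened, in exactly this order
assert check.call_args_list == [call("table_b"), call("table_a")]
```
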
@@ -1174,7 +1489,7 @@ def test_validate_incomplete_metadata_in_infer_process(
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_gather_existed_columns")
-def test_validate_metadata_with_not_existent_source(
+def test_validate_metadata_with_not_existent_source_in_train_process(
mock_gather_existed_columns,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
@@ -1183,26 +1498,26 @@ def test_validate_metadata_with_not_existent_source(
rp_logger
):
"""
- Test the validation of the metadata of one table.
+ Test the validation of the metadata of one table during the training process.
The source of the table is not existent.
"""
rp_logger.info(
- "Test the validation of the metadata of one table. "
+ "Test the validation of the metadata of one table during the training process. "
"The source of the table is not existent."
)
test_metadata = {
- "test_table": {
- "train_settings": {
- "source": "path/to/test_table.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/test_table.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1213,9 +1528,9 @@ def test_validate_metadata_with_not_existent_source(
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
- mock_gather_existed_columns.assert_called_once()
- mock_check_existence_of_key_columns.assert_called_once()
- mock_check_existence_of_referenced_columns.assert_called_once()
+ mock_gather_existed_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_key_columns.assert_called_once_with("test_table")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("test_table")
mock_validate_referential_integrity.assert_not_called()
message = (
"The validation of the metadata has been failed. The error(s) found in - "
@@ -1243,22 +1558,22 @@ def test_validate_incomplete_metadata_with_absent_parent_metadata_in_metadata_st
"The information of the parent table is absent in the metadata storage."
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_c",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_c",
+ "columns": ["id"]
}
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1266,11 +1581,12 @@ def test_validate_incomplete_metadata_with_absent_parent_metadata_in_metadata_st
type_of_process="train",
metadata_path=FAKE_METADATA_PATH
)
+ validator.errors = dict()
validator.run()
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
message = (
@@ -1284,7 +1600,7 @@ def test_validate_incomplete_metadata_with_absent_parent_metadata_in_metadata_st
rp_logger.info(SUCCESSFUL_MESSAGE)
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@@ -1294,7 +1610,7 @@ def test_validate_incomplete_metadata_with_wrong_referential_integrity(
mock_check_existence_of_source,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
test_metadata_storage,
caplog,
rp_logger
@@ -1312,22 +1628,22 @@ def test_validate_incomplete_metadata_with_wrong_referential_integrity(
"which not correspond to the list of the columns of the FK"
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_d",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_d",
+ "columns": ["id"]
}
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1338,8 +1654,8 @@ def test_validate_incomplete_metadata_with_wrong_referential_integrity(
validator.run()
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_d"
+ "parent_table": "table_d",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == {
@@ -1361,7 +1677,7 @@ def test_validate_incomplete_metadata_with_wrong_referential_integrity(
"table_d": {
"train_settings": {
"source": "path/to/table_a.csv",
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"destination": "path/to/generated_table_a.csv"
@@ -1378,7 +1694,7 @@ def test_validate_incomplete_metadata_with_wrong_referential_integrity(
assert mock_check_existence_of_source.call_count == 2
assert mock_check_existence_of_key_columns.call_count == 2
assert mock_check_existence_of_referenced_columns.call_count == 2
- mock_check_existence_of_success_file.assert_called_once()
+ mock_check_completion_of_training.assert_called_once_with("table_d")
message = (
"The validation of the metadata has been failed. "
"The error(s) found in - \"validate referential integrity\": "
@@ -1397,29 +1713,29 @@ def test_validate_metadata_with_not_existent_destination(
rp_logger
):
"""
- Test the validation of the metadata of one table used in the inference process.
+ Test the validation of the metadata of one table during the inference process.
The destination for the generated data is not existent.
"""
rp_logger.info(
- "Test the validation of the metadata of one table used in the inference process. "
+ "Test the validation of the metadata of one table during the inference process. "
"The destination of the table is not existent."
)
test_metadata = {
- "test_table": {
- "train_settings": {
- "source": "path/to/test_table.csv"
- },
- "infer_settings": {
- "destination": "path/to/generated_test_table.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "test_table": {
+ "train_settings": {
+ "source": "path/to/test_table.csv"
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_test_table.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1430,7 +1746,7 @@ def test_validate_metadata_with_not_existent_destination(
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
- mock_validate_referential_integrity.assert_called_once()
+ mock_validate_referential_integrity.assert_called_once_with()
message = (
"The validation of the metadata has been failed. The error(s) found in - "
"\"check existence of the destination\": {\"test_table\": \"It seems that "
@@ -1443,7 +1759,7 @@ def test_validate_metadata_with_not_existent_destination(
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@@ -1455,7 +1771,7 @@ def test_validate_incomplete_metadata_with_absent_success_file_of_parent_table_i
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
mock_check_existence_of_generated_data,
test_metadata_storage,
caplog,
@@ -1463,34 +1779,34 @@ def test_validate_incomplete_metadata_with_absent_success_file_of_parent_table_i
):
"""
Test the validation of the incomplete metadata of one table contained the foreign key
- used in the training process.
+ during the training process.
The information of the parent table is present in the metadata storage,
but the parent table hasn't been trained previously
that's why the success file of the parent table is absent
"""
rp_logger.info(
"Test the validation of the incomplete metadata of one table contained the foreign key "
- "used in the training process. The information of the parent table is present "
+ "during the training process. The information of the parent table is present "
"in the metadata storage, but the parent table hasn't been trained previously "
"that's why the success file of the parent table is absent"
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1500,16 +1816,45 @@ def test_validate_incomplete_metadata_with_absent_success_file_of_parent_table_i
)
validator.run()
assert mock_gather_existed_columns.call_count == 2
- assert mock_check_existence_of_source.call_count == 2
- assert mock_check_existence_of_key_columns.call_count == 2
- assert mock_check_existence_of_referenced_columns.call_count == 2
- mock_validate_referential_integrity.assert_called_once()
- mock_check_existence_of_success_file.assert_called_once()
+ mock_check_existence_of_source.assert_called_once_with("table_b")
+ mock_check_existence_of_key_columns.assert_called_once_with("table_b")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table_b")
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_key",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_metadata={
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ }
+ }
+ )
+ mock_check_completion_of_training.assert_called_once_with("table_a")
mock_check_existence_of_generated_data.assert_not_called()
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == {
@@ -1531,7 +1876,7 @@ def test_validate_incomplete_metadata_with_absent_success_file_of_parent_table_i
"table_a": {
"train_settings": {
"source": "path/to/table_a.csv",
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"destination": "path/to/generated_table_a.csv"
@@ -1550,8 +1895,9 @@ def test_validate_incomplete_metadata_with_absent_success_file_of_parent_table_i
}
message = (
"The validation of the metadata has been failed. The error(s) found in - "
- "\"check existence of the success file\": {\"table_a\": \"The table - 'table_a'"
- "hasn't been trained completely. Please, retrain this table first\"}"
+ "\"check completion of the training process\": "
+ "{\"table_a\": \"The training of the table - 'table_a'"
+ "hasn't been completed. Please, retrain the table.\"}"
)
assert str(error.value) == message
assert message in caplog.text
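Worth spelling out, since this pattern recurs through the rest of the patch: `assert_called_once_with(...)` verifies both the call count and the exact arguments, while the old `assert_called_once()` only verified the count. A minimal sketch with plain `unittest.mock`, nothing syngen-specific:

```python
# Minimal sketch: the *_with variant also verifies the arguments
# of the single recorded call.
from unittest.mock import MagicMock

validate = MagicMock()
validate(fk_name="fk_key", fk_config={"type": "FK"})

validate.assert_called_once()  # checks the call count only
validate.assert_called_once_with(
    fk_name="fk_key", fk_config={"type": "FK"}
)  # checks the count and the arguments

# Note these helpers return None, so never wrap them in a bare `assert`:
# `assert validate.assert_called_once_with(...)` fails even when the call matched.
```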
@@ -1561,7 +1907,9 @@ def test_validate_incomplete_metadata_with_absent_success_file_of_parent_table_i
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_generated_data")
@patch.object(Validator, "_check_existence_of_destination")
+@patch.object(Validator, "_check_completion_of_training")
def test_validate_incomplete_metadata_with_absent_generated_of_parent_table_in_infer_process(
+ mock_check_completion_of_training,
mock_check_existence_of_destination,
mock_check_existence_of_generated_data,
mock_validate_referential_integrity,
@@ -1571,33 +1919,33 @@ def test_validate_incomplete_metadata_with_absent_generated_of_parent_table_in_i
):
"""
Test the validation of the incomplete metadata of one table contained the foreign key
- used in the inference process.
+ during the inference process.
The information of the parent table is present in the metadata storage,
but the generated data of the parent table hasn't been generated previously
"""
rp_logger.info(
"Test the validation of the incomplete metadata of one table contained the foreign key "
- "used in the inference process. The information of the parent table is present "
+ "during the inference process. The information of the parent table is present "
"in the metadata storage, but the generated data of the parent table hasn't been "
"generated previously"
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1609,8 +1957,8 @@ def test_validate_incomplete_metadata_with_absent_generated_of_parent_table_in_i
assert mock_check_existence_of_destination.call_count == 2
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == {
@@ -1632,7 +1980,7 @@ def test_validate_incomplete_metadata_with_absent_generated_of_parent_table_in_i
"table_a": {
"train_settings": {
"source": "path/to/table_a.csv",
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"destination": "path/to/generated_table_a.csv"
@@ -1646,12 +1994,43 @@ def test_validate_incomplete_metadata_with_absent_generated_of_parent_table_in_i
"type": "UQ",
"columns": ["name"]
}
- }
+ },
+ "format": {}
}
}
- assert mock_check_existence_of_destination.call_count == 2
- mock_validate_referential_integrity.assert_called_once()
- mock_check_existence_of_generated_data.assert_called_once()
+    mock_check_existence_of_destination.assert_called_once_with("table_b")
+ mock_validate_referential_integrity.assert_called_once_with(
+ fk_name="fk_key",
+ fk_config={
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
+ }
+ },
+ parent_config={
+ "train_settings": {
+ "source": "path/to/table_a.csv",
+ "reports": ["accuracy", "sample"]
+ },
+ "infer_settings": {
+ "destination": "path/to/generated_table_a.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ },
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["name"]
+ }
+ },
+ "format": {}
+ }
+ )
+ mock_check_existence_of_generated_data.assert_called_once_with("table_a")
message = (
"The validation of the metadata has been failed. "
"The error(s) found in - \"check existence of the generated data\": {"
@@ -1664,55 +2043,65 @@ def test_validate_incomplete_metadata_with_absent_generated_of_parent_table_in_i
    rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
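The test added below stacks `@pytest.mark.parametrize` on top of a column of `@patch.object` decorators. As a reminder of how the resulting arguments line up: patches inject their mocks bottom-up (the decorator nearest the function comes first), and parametrized values such as `value` are matched by name after them. A minimal sketch, not syngen code:

```python
# Sketch of decorator-to-argument ordering with stacked patches.
from unittest.mock import patch

import pytest


class Service:  # hypothetical stand-in
    def first(self):
        pass

    def second(self):
        pass


@pytest.mark.parametrize("value", [1, 2])
@patch.object(Service, "second")
@patch.object(Service, "first")
def test_argument_order(mock_first, mock_second, value):
    # mock_first corresponds to the decorator nearest the function,
    # mock_second to the one above it; `value` is matched by name.
    assert mock_first is not mock_second
```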
+@pytest.mark.parametrize("value", [
+ ["accuracy", "sample"],
+ ["accuracy", "metrics_only"],
+ ["sample", "metrics_only"],
+ ["accuracy"],
+ ["metrics_only"],
+])
@patch.object(Validator, "_check_existence_of_generated_data")
-@patch.object(Validator, "_check_existence_of_success_file")
+@patch.object(Validator, "_check_completion_of_training")
@patch.object(Validator, "_validate_referential_integrity")
@patch.object(Validator, "_check_existence_of_referenced_columns")
@patch.object(Validator, "_check_existence_of_key_columns")
@patch.object(Validator, "_check_existence_of_source")
@patch.object(Validator, "_gather_existed_columns")
-def test_validate_incomplete_metadata_without_gen_parent_table_in_train_process_with_print_report(
+def test_validate_incomplete_metadata_without_gen_parent_table_in_train_process_with_reports(
mock_gather_existed_columns,
mock_check_existence_of_source,
mock_check_existence_of_key_columns,
mock_check_existence_of_referenced_columns,
mock_validate_referential_integrity,
- mock_check_existence_of_success_file,
+ mock_check_completion_of_training,
mock_check_existence_of_generated_data,
test_metadata_storage,
caplog,
+ value,
rp_logger
):
"""
Test the validation of the incomplete metadata of one table contained the foreign key
- used in the training process with the parameter 'print_report' set to True.
+    during the training process with report types that require
+    the generation of synthetic data.
The information of the parent table is present in the metadata storage,
but the generated data of the parent table hasn't been generated previously
"""
rp_logger.info(
- "Test the validation of the incomplete metadata of one table contained the foreign key "
- "used in the training process with the parameter 'print_report' set to True. "
+ "Test the validation of the incomplete metadata of one table "
+        "containing the foreign key during the training process with "
+        "report types that require the generation of synthetic data. "
"The information of the parent table is present in the metadata storage, "
"but the generated data of the parent table hasn't been generated previously"
)
metadata = {
- "table_b": {
- "train_settings": {
- "source": "path/to/table_b.csv",
- "print_report": True
- },
- "keys": {
- "fk_key": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "table_b": {
+ "train_settings": {
+ "source": "path/to/table_b.csv",
+ "reports": value
+ },
+ "keys": {
+ "fk_key": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
+ }
with pytest.raises(ValidationError) as error:
with caplog.at_level("ERROR"):
validator = Validator(
@@ -1723,8 +2112,8 @@ def test_validate_incomplete_metadata_without_gen_parent_table_in_train_process_
validator.run()
assert validator.mapping == {
"fk_key": {
- "parent_columns": ["id"],
- "parent_table": "table_a"
+ "parent_table": "table_a",
+ "parent_columns": ["id"]
}
}
assert validator.merged_metadata == {
@@ -1746,7 +2135,7 @@ def test_validate_incomplete_metadata_without_gen_parent_table_in_train_process_
"table_a": {
"train_settings": {
"source": "path/to/table_a.csv",
- "print_report": True
+ "reports": ["accuracy", "sample"]
},
"infer_settings": {
"destination": "path/to/generated_table_a.csv"
@@ -1760,17 +2149,17 @@ def test_validate_incomplete_metadata_without_gen_parent_table_in_train_process_
"type": "UQ",
"columns": ["name"]
}
- }
+ },
+ "format": {}
}
}
assert mock_gather_existed_columns.call_count == 2
- assert mock_check_existence_of_source.call_count == 2
- assert mock_check_existence_of_key_columns.call_count == 2
- assert mock_check_existence_of_referenced_columns.call_count == 2
- mock_check_existence_of_generated_data.assert_called_once()
- mock_check_existence_of_success_file.assert_called_once()
- mock_check_existence_of_generated_data.assert_called_once()
- mock_validate_referential_integrity.assert_called_once()
+ mock_check_existence_of_source.assert_called_once_with("table_b")
+ mock_check_existence_of_key_columns.assert_called_once_with("table_b")
+ mock_check_existence_of_referenced_columns.assert_called_once_with("table_b")
+    mock_validate_referential_integrity.assert_called_once()
+ mock_check_completion_of_training.assert_called_once_with("table_a")
+ mock_check_existence_of_generated_data.assert_called_once_with("table_a")
message = (
"The validation of the metadata has been failed. "
"The error(s) found in - \"check existence of the generated data\": {"
@@ -1791,25 +2180,24 @@ def test_check_not_existent_key_column_in_pk(rp_logger):
"Test that the column of the primary key doesn't exist in the source table"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/"
- "csv_tables/table_with_data.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["non-existent column"]
- }
+ "table_a": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/table_with_data.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["non-existent column"]
}
}
}
- validator = Validator(
- metadata=test_metadata,
- type_of_process="train",
- metadata_path=FAKE_METADATA_PATH
- )
+ }
with pytest.raises(ValidationError) as error:
+ validator = Validator(
+ metadata=test_metadata,
+ type_of_process="train",
+ metadata_path=FAKE_METADATA_PATH
+ )
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
@@ -1831,25 +2219,24 @@ def test_check_not_existent_key_column_in_uq(rp_logger):
"Test that the column of the unique key doesn't exist in the source table"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/"
- "csv_tables/table_with_data.csv"
- },
- "keys": {
- "uq_id": {
- "type": "UQ",
- "columns": ["non-existent column"]
- }
+ "table_a": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/table_with_data.csv"
+ },
+ "keys": {
+ "uq_id": {
+ "type": "UQ",
+ "columns": ["non-existent column"]
}
}
}
- validator = Validator(
- metadata=test_metadata,
- type_of_process="train",
- metadata_path=FAKE_METADATA_PATH
- )
+ }
with pytest.raises(ValidationError) as error:
+ validator = Validator(
+ metadata=test_metadata,
+ type_of_process="train",
+ metadata_path=FAKE_METADATA_PATH
+ )
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
@@ -1871,41 +2258,40 @@ def test_check_not_existent_key_column_in_fk(rp_logger):
"Test that the column of the foreign key doesn't exist in the child table"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/"
- "csv_tables/table_with_data.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "table_a": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/table_with_data.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
+ }
+ },
+ "table_b": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/"
+ "child_table_with_data.csv"
},
- "table_b": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/"
- "child_table_with_data.csv"
- },
- "keys": {
- "fk_id": {
- "type": "FK",
- "columns": ["non-existent column"],
- "references": {
- "table": "table_a",
- "columns": ["id"]
- }
+ "keys": {
+ "fk_id": {
+ "type": "FK",
+ "columns": ["non-existent column"],
+ "references": {
+ "table": "table_a",
+ "columns": ["id"]
}
}
}
}
- validator = Validator(
- metadata=test_metadata,
- type_of_process="train",
- metadata_path=FAKE_METADATA_PATH
- )
+ }
with pytest.raises(ValidationError) as error:
+ validator = Validator(
+ metadata=test_metadata,
+ type_of_process="train",
+ metadata_path=FAKE_METADATA_PATH
+ )
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
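These key-column tests also move the `Validator(...)` construction inside the `with pytest.raises(ValidationError)` block. That matters whenever validation can already fail in the constructor: an exception raised there would otherwise escape the `pytest.raises` context and error the test instead of passing it. A minimal sketch of the pattern (the `FailsEarly` class is a stand-in, not syngen code):

```python
# Sketch: an exception raised during construction is only captured when
# the constructor call itself sits inside the pytest.raises block.
import pytest
from marshmallow import ValidationError


class FailsEarly:  # hypothetical stand-in for a validator
    def __init__(self):
        raise ValidationError("invalid metadata")

    def run(self):
        pass


def test_constructor_error_is_captured():
    with pytest.raises(ValidationError):
        validator = FailsEarly()  # raises here, still inside the block
        validator.run()
```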
@@ -1926,41 +2312,41 @@ def test_check_not_existent_referenced_table_in_fk(test_metadata_storage, rp_log
"Test that the table of the foreign key doesn't exist in the metadata"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/"
- "csv_tables/table_with_data.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
+ "table_a": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/table_with_data.csv"
+ },
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
}
+ }
+ },
+ "table_b": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/"
+ "child_table_with_data.csv"
},
- "table_b": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/"
- "child_table_with_data.csv"
- },
- "keys": {
- "fk_id": {
- "type": "FK",
- "columns": ["non-existent column"],
- "references": {
- "table": "non-existent table",
- "columns": ["id"]
- }
+ "keys": {
+ "fk_id": {
+ "type": "FK",
+ "columns": ["non-existent column"],
+ "references": {
+ "table": "non-existent table",
+ "columns": ["id"]
}
}
}
}
- validator = Validator(
- metadata=test_metadata,
- type_of_process="train",
- metadata_path=FAKE_METADATA_PATH
- )
+ }
with pytest.raises(ValidationError) as error:
+ validator = Validator(
+ metadata=test_metadata,
+ type_of_process="train",
+ metadata_path=FAKE_METADATA_PATH
+ )
+ validator.errors = dict()
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
@@ -1981,41 +2367,40 @@ def test_check_not_existent_referenced_columns_in_fk(rp_logger):
"Test that the referenced columns of the foreign key doesn't exist in the parent table"
)
test_metadata = {
- "table_a": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/"
- "csv_tables/table_with_data.csv"
- },
- "keys": {
- "pk_id": {
- "type": "PK",
- "columns": ["id"]
- }
- }
+ "table_a": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/table_with_data.csv"
},
- "table_b": {
- "train_settings": {
- "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/"
- "child_table_with_data.csv"
+ "keys": {
+ "pk_id": {
+ "type": "PK",
+ "columns": ["id"]
+ }
+ }
+ },
+ "table_b": {
+ "train_settings": {
+ "source": f"{DIR_NAME}/unit/data_loaders/fixtures/csv_tables/"
+ "child_table_with_data.csv"
},
- "keys": {
- "fk_id": {
- "type": "FK",
- "columns": ["id"],
- "references": {
- "table": "table_a",
- "columns": ["non-existent column"]
- }
+ "keys": {
+ "fk_id": {
+ "type": "FK",
+ "columns": ["id"],
+ "references": {
+ "table": "table_a",
+ "columns": ["non-existent column"]
}
}
}
}
- validator = Validator(
- metadata=test_metadata,
- type_of_process="train",
- metadata_path=FAKE_METADATA_PATH
- )
+ }
with pytest.raises(ValidationError) as error:
+ validator = Validator(
+ metadata=test_metadata,
+ type_of_process="train",
+ metadata_path=FAKE_METADATA_PATH
+ )
validator.run()
assert validator.mapping == {}
assert validator.merged_metadata == test_metadata
diff --git a/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_FK_key.yaml b/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_FK_key.yaml
index 48a9adc2..bac9f909 100644
--- a/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_FK_key.yaml
+++ b/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_FK_key.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,7 +11,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
+ reports: all
pk_test:
train_settings:
@@ -27,7 +27,7 @@ fk_test:
source: ./data/fk_test.csv
epochs: 10
drop_null: False
- print_report: true
+ reports: all
batch_size: 200
infer_settings:
@@ -35,7 +35,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
+ reports: none
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_PK_key.yaml b/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_PK_key.yaml
index 9b094869..3346ea1f 100644
--- a/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_PK_key.yaml
+++ b/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_PK_key.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,7 +11,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
+ reports: all
pk_test:
train_settings:
@@ -31,7 +31,7 @@ fk_test:
source: ./data/fk_test.csv
epochs: 10
drop_null: False
- print_report: true
+ reports: all
batch_size: 200
infer_settings:
@@ -39,7 +39,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
+ reports: none
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_UQ_key.yaml b/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_UQ_key.yaml
index d239dbba..db089b41 100644
--- a/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_UQ_key.yaml
+++ b/src/tests/unit/validation_schema/fixtures/metadata_file_with_invalid_UQ_key.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,7 +11,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
+ reports: all
pk_test:
train_settings:
@@ -27,7 +27,7 @@ fk_test:
source: ./data/fk_test.csv
epochs: 10
drop_null: False
- print_report: true
+ reports: all
batch_size: 200
infer_settings:
@@ -35,7 +35,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
+ reports: none
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/metadata_file_without_required_fields.yaml b/src/tests/unit/validation_schema/fixtures/metadata_file_without_required_fields.yaml
index 23e85825..84a315f9 100644
--- a/src/tests/unit/validation_schema/fixtures/metadata_file_without_required_fields.yaml
+++ b/src/tests/unit/validation_schema/fixtures/metadata_file_without_required_fields.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,7 +11,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
+ reports: all
pk_test:
keys:
@@ -24,7 +24,7 @@ fk_test:
train_settings:
epochs: 10
drop_null: False
- print_report: true
+ reports: all
batch_size: 200
infer_settings:
@@ -32,7 +32,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
+ reports: none
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/metadata_file_without_sources.yaml b/src/tests/unit/validation_schema/fixtures/metadata_file_without_sources.yaml
index 75d2bd34..713a6751 100644
--- a/src/tests/unit/validation_schema/fixtures/metadata_file_without_sources.yaml
+++ b/src/tests/unit/validation_schema/fixtures/metadata_file_without_sources.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,11 +11,9 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
- get_infer_metrics: false
+ reports: all
pk_test:
- train_settings:
keys:
pk_test_pk_id:
type: "PK"
@@ -26,7 +24,7 @@ fk_test:
train_settings:
epochs: 10
drop_null: False
- print_report: true
+ reports: all
batch_size: 200
format:
@@ -44,8 +42,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
- get_infer_metrics: false
+ reports: none
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/metadata_file_without_training_settings.yaml b/src/tests/unit/validation_schema/fixtures/metadata_file_without_training_settings.yaml
index f9079bc2..74ad857c 100644
--- a/src/tests/unit/validation_schema/fixtures/metadata_file_without_training_settings.yaml
+++ b/src/tests/unit/validation_schema/fixtures/metadata_file_without_training_settings.yaml
@@ -5,8 +5,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
- get_infer_metrics: false
+ reports: all
pk_test:
keys:
@@ -31,8 +30,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
- get_infer_metrics: false
+ reports: none
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/valid_metadata_file.yaml b/src/tests/unit/validation_schema/fixtures/valid_metadata_file.yaml
index 379c08a9..65794ed0 100644
--- a/src/tests/unit/validation_schema/fixtures/valid_metadata_file.yaml
+++ b/src/tests/unit/validation_schema/fixtures/valid_metadata_file.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,8 +11,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
- get_infer_metrics: false
+ reports: all
pk_test:
train_settings:
@@ -28,7 +27,10 @@ fk_test:
source: ./data/fk_test.csv
epochs: 10
drop_null: False
- print_report: true
+ reports:
+ - accuracy
+ - sample
+ - metrics_only
batch_size: 200
format:
@@ -46,8 +48,9 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
- get_infer_metrics: false
+ reports:
+ - accuracy
+ - metrics_only
keys:
fk_test_pk_id:
diff --git a/src/tests/unit/validation_schema/fixtures/valid_metadata_file_for_excel_table.yaml b/src/tests/unit/validation_schema/fixtures/valid_metadata_file_for_excel_table.yaml
index 065bcdd3..79929a12 100644
--- a/src/tests/unit/validation_schema/fixtures/valid_metadata_file_for_excel_table.yaml
+++ b/src/tests/unit/validation_schema/fixtures/valid_metadata_file_for_excel_table.yaml
@@ -3,7 +3,7 @@ global:
epochs: 10
drop_null: False
row_limit: 500
- print_report: false
+ reports: none
batch_size: 100
infer_settings: # Settings for infer process
@@ -11,7 +11,7 @@ global:
run_parallel: False
batch_size: 100
random_seed: 5
- print_report: true
+ reports: all
pk_test:
train_settings:
diff --git a/src/tests/unit/validation_schema/fixtures/valid_metadata_file_with_absent_global_settings.yaml b/src/tests/unit/validation_schema/fixtures/valid_metadata_file_with_absent_global_settings.yaml
index e8164f6b..deb7dc2b 100644
--- a/src/tests/unit/validation_schema/fixtures/valid_metadata_file_with_absent_global_settings.yaml
+++ b/src/tests/unit/validation_schema/fixtures/valid_metadata_file_with_absent_global_settings.yaml
@@ -12,7 +12,7 @@ fk_test:
source: ./data/fk_test.csv
epochs: 10
drop_null: False
- print_report: true
+ reports: all
batch_size: 200
infer_settings:
@@ -20,7 +20,7 @@ fk_test:
run_parallel: False
batch_size: 200
random_seed: 1
- print_report: false
+ reports: none
keys:
fk_test_pk_id:
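Every fixture above applies the same mechanical migration: the boolean `print_report` flag and the separate `get_infer_metrics` flag are folded into a single `reports` option that accepts either a scalar or a list. Roughly, judging from these fixtures (the mapping is inferred, not taken from a migration guide):

```yaml
# Before: two separate flags
train_settings:
  print_report: true
infer_settings:
  print_report: false
  get_infer_metrics: false

# After: one "reports" option, scalar or list
train_settings:
  reports: all            # or: none / accuracy / sample / metrics_only
infer_settings:
  reports: none
  # the list form is also accepted:
  # reports:
  #   - accuracy
  #   - metrics_only
```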
diff --git a/src/tests/unit/validation_schema/test_validation_schema.py b/src/tests/unit/validation_schema/test_validation_schema.py
index 5d250254..8d5e1614 100644
--- a/src/tests/unit/validation_schema/test_validation_schema.py
+++ b/src/tests/unit/validation_schema/test_validation_schema.py
@@ -1,27 +1,58 @@
import pytest
-from typing import Dict
-import yaml
-from yaml import Loader
from marshmallow import ValidationError
from syngen.ml.validation_schema import ValidationSchema
+from syngen.ml.data_loaders import MetadataLoader
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME
-def load_metadata_file(metadata_path) -> Dict:
- with open(metadata_path, "r", encoding="utf-8") as metadata_file:
- metadata = yaml.load(metadata_file, Loader=Loader)
- return metadata
-
-
def test_valid_metadata_file(rp_logger, caplog):
rp_logger.info("Test the validation of the schema of the valid metadata file")
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
+ with caplog.at_level(level="DEBUG"):
+ ValidationSchema(
+ metadata=metadata,
+ metadata_path=path_to_metadata,
+ validation_source=True,
+ process="train"
+ ).validate_schema()
+ assert "The schema of the metadata is valid" in caplog.text
+ rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
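One more change that runs through this whole module: the local `load_metadata_file` helper removed above is replaced everywhere by syngen's `MetadataLoader`. The swap assumes `load_data()` returns the same parsed dictionary as a plain YAML load; for reference:

```python
# The removed helper - a plain YAML load:
import yaml
from yaml import Loader


def load_metadata_file(metadata_path):
    with open(metadata_path, "r", encoding="utf-8") as metadata_file:
        return yaml.load(metadata_file, Loader=Loader)


# Its replacement throughout the module:
# from syngen.ml.data_loaders import MetadataLoader
# metadata = MetadataLoader(path_to_metadata).load_data()
```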
+@pytest.mark.parametrize(
+ "type_of_process, reports", [
+ ("train", []),
+ ("infer", []),
+ ("train", ["accuracy", "sample"]),
+ ("infer", ["accuracy"]),
+ ("train", ["accuracy"]),
+ ("train", ["sample"]),
+ ("train", ["metrics_only"]),
+ ("infer", ["metrics_only"]),
+ ("train", ["accuracy", "metrics_only"]),
+ ("infer", ["accuracy", "metrics_only"]),
+ ("train", ["sample", "metrics_only"])
+ ]
+)
+def test_valid_metadata_file_with_diff_types_of_reports(
+ type_of_process, reports, rp_logger, caplog
+):
+ rp_logger.info(
+ "Test the validation of the schema of the valid metadata file "
+ f"with reports - {', '.join(reports)} during the {type_of_process} process"
+ )
+ path_to_metadata = (
+ f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
+ )
+ metadata = MetadataLoader(path_to_metadata).load_data()
+    settings_section = "train_settings" if type_of_process == "train" else "infer_settings"
+    metadata["global"][settings_section]["reports"] = reports
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -44,7 +75,7 @@ def test_valid_metadata_file_with_source_contained_path_to_excel_table(
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"valid_metadata_file_for_excel_table.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -65,7 +96,7 @@ def test_valid_metadata_file_without_global_settings(rp_logger, caplog):
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"valid_metadata_file_with_absent_global_settings.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -86,7 +117,7 @@ def test_valid_metadata_file_only_with_required_fields(rp_logger, caplog):
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"valid_metadata_file_only_with_required_fields.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -142,9 +173,9 @@ def test_valid_metadata_file_only_with_required_fields(rp_logger, caplog):
"{'batch_size': ['Not a valid integer.']}}}",
),
(
- {"print_report": "not a valid type of a value"},
+ {"reports": "not a valid type of a value"},
"The details are - {'fk_test': {'train_settings': "
- "{'print_report': ['Not a valid boolean.']}}}",
+ "{'reports': ['Invalid value.']}}}",
),
(
{"column_types": {"invalid_type": ["column_1", "column_2"]}},
@@ -152,11 +183,6 @@ def test_valid_metadata_file_only_with_required_fields(rp_logger, caplog):
"defaultdict(, {'invalid_type': {"
"'key': ['Must be one of: categorical.']}})}}}",
),
- (
- {"get_infer_metrics": "invalid parameter"},
- "The details are - {'fk_test': {'train_settings': {"
- "'get_infer_metrics': ['Unknown field.']}}}",
- ),
],
)
def test_metadata_file_with_invalid_training_settings(
@@ -168,7 +194,7 @@ def test_metadata_file_with_invalid_training_settings(
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
metadata["fk_test"]["train_settings"].update(wrong_setting)
with pytest.raises(ValidationError) as error:
ValidationSchema(
@@ -223,14 +249,9 @@ def test_metadata_file_with_invalid_training_settings(
"{'batch_size': ['Not a valid integer.']}}}",
),
(
- {"print_report": "not a valid type of a value"},
+ {"reports": "not a valid type of a value"},
"The details are - {'global': {'train_settings': "
- "{'print_report': ['Not a valid boolean.']}}}",
- ),
- (
- {"get_infer_metrics": "invalid parameter"},
- "The details are - {'global': {'train_settings': {"
- "'get_infer_metrics': ['Unknown field.']}}}",
+ "{'reports': ['Invalid value.']}}}",
),
],
)
@@ -243,7 +264,7 @@ def test_metadata_file_with_invalid_global_training_settings(
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
metadata["global"]["train_settings"].update(wrong_setting)
with pytest.raises(ValidationError) as error:
ValidationSchema(
@@ -293,14 +314,9 @@ def test_metadata_file_with_invalid_global_training_settings(
"{'random_seed': ['Not a valid integer.']}}}",
),
(
- {"print_report": "not a valid type of a value"},
- "The details are - {'fk_test': {'infer_settings': {"
- "'print_report': ['Not a valid boolean.']}}}",
- ),
- (
- {"get_infer_metrics": "not a valid type of a value"},
+ {"reports": "not a valid type of a value"},
"The details are - {'fk_test': {'infer_settings': {"
- "'get_infer_metrics': ['Not a valid boolean.']}}}",
+ "'reports': ['Invalid value.']}}}",
),
],
)
@@ -311,7 +327,7 @@ def test_metadata_file_with_invalid_infer_settings(
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
metadata["fk_test"]["infer_settings"].update(wrong_setting)
with pytest.raises(ValidationError) as error:
ValidationSchema(
@@ -356,14 +372,9 @@ def test_metadata_file_with_invalid_infer_settings(
"{'random_seed': ['Not a valid integer.']}}}",
),
(
- {"print_report": "not a valid type of a value"},
- "The details are - {'global': {'infer_settings': {"
- "'print_report': ['Not a valid boolean.']}}}",
- ),
- (
- {"get_infer_metrics": "not a valid type of a value"},
+ {"reports": "not a valid type of a value"},
"The details are - {'global': {'infer_settings': {"
- "'get_infer_metrics': ['Not a valid boolean.']}}}",
+ "'reports': ['Invalid value.']}}}",
),
],
)
@@ -376,7 +387,7 @@ def test_metadata_file_with_invalid_global_infer_settings(
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
metadata["global"]["infer_settings"].update(wrong_setting)
with pytest.raises(ValidationError) as error:
ValidationSchema(
@@ -465,7 +476,7 @@ def test_metadata_file_with_invalid_format_settings_for_csv_table(
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/valid_metadata_file.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
metadata["fk_test"]["format"].update(wrong_setting)
with pytest.raises(ValidationError) as error:
ValidationSchema(
@@ -529,7 +540,7 @@ def test_metadata_file_with_invalid_format_settings_for_excel_table(
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"valid_metadata_file_for_excel_table.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
metadata["pk_test"]["format"].update(wrong_setting)
with pytest.raises(ValidationError) as error:
ValidationSchema(
@@ -551,7 +562,7 @@ def test_metadata_file_with_absent_required_fields(rp_logger):
)
path_to_metadata = (f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_required_fields.yaml")
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as error:
ValidationSchema(
metadata=metadata,
@@ -561,8 +572,8 @@ def test_metadata_file_with_absent_required_fields(rp_logger):
).validate_schema()
assert str(error.value) == (
"Validation error(s) found in the schema of the metadata. "
- "The details are - {'pk_test': {'train_settings': ["
- "'Missing data for required field.']}, 'fk_test': {"
+ "The details are - {'pk_test': {'train_settings': "
+ "{'source': ['Missing data for required field.']}}, 'fk_test': {"
"'train_settings': {'source': ['Missing data for required field.']}}}"
)
rp_logger.info(SUCCESSFUL_MESSAGE)
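The corrected expectation above matches how marshmallow reports a missing field inside a required nested schema: the error is keyed by the inner `source` field rather than attached to `train_settings` as a whole. A minimal sketch (the schema names are illustrative stand-ins, not syngen's actual classes):

```python
# Sketch of the nested-error shape asserted above.
from marshmallow import Schema, fields


class TrainSettingsSchema(Schema):  # hypothetical stand-in
    source = fields.String(required=True)


class TableSchema(Schema):  # hypothetical stand-in
    train_settings = fields.Nested(TrainSettingsSchema, required=True)


errors = TableSchema().validate({"train_settings": {}})
# {'train_settings': {'source': ['Missing data for required field.']}}
```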
@@ -576,7 +587,7 @@ def test_metadata_file_with_invalid_PK_key_contained_references_section(rp_logge
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/metadata_file_with_invalid_PK_key.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as error:
ValidationSchema(
metadata=metadata,
@@ -601,7 +612,7 @@ def test_metadata_file_with_invalid_UQ_key_contained_references_section(rp_logge
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/metadata_file_with_invalid_UQ_key.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as error:
ValidationSchema(
metadata=metadata,
@@ -626,7 +637,7 @@ def test_metadata_file_with_invalid_FK_key_without_references_section(rp_logger)
path_to_metadata = (
f"{DIR_NAME}/unit/validation_schema/fixtures/metadata_file_with_invalid_FK_key.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as error:
ValidationSchema(
metadata=metadata,
@@ -711,7 +722,7 @@ def test_validation_schema_of_keys(rp_logger, path_to_metadata, expected_error):
rp_logger.info(
"Test the validation of the schema of the metadata file with invalid section 'keys'"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as error:
ValidationSchema(
metadata=metadata,
@@ -735,7 +746,7 @@ def test_valid_metadata_file_without_sources_during_training_process_without_val
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_sources.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -758,7 +769,7 @@ def test_valid_metadata_file_without_sources_during_training_process_with_valida
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_sources.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as e:
ValidationSchema(
metadata=metadata,
@@ -768,7 +779,8 @@ def test_valid_metadata_file_without_sources_during_training_process_with_valida
).validate_schema()
assert str(e.value) == (
"Validation error(s) found in the schema of the metadata. "
- "The details are - {'pk_test': {'train_settings': ['Field may not be null.']}, "
+ "The details are - {'pk_test': {'train_settings': {"
+ "'source': ['Missing data for required field.']}}, "
"'fk_test': {'train_settings': {'source': ['Missing data for required field.']}}}"
)
rp_logger.info(SUCCESSFUL_MESSAGE)
@@ -785,7 +797,7 @@ def test_valid_metadata_file_without_sources_during_infer_process_without_valida
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_sources.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -808,7 +820,7 @@ def test_valid_metadata_file_without_sources_during_inference_process_with_valid
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_sources.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -831,7 +843,7 @@ def test_valid_metadata_file_without_training_settings_during_train_process_with
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_training_settings.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -854,7 +866,7 @@ def test_valid_metadata_file_without_training_settings_during_train_process_with
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_training_settings.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with pytest.raises(ValidationError) as e:
ValidationSchema(
metadata=metadata,
@@ -882,7 +894,7 @@ def test_valid_metadata_file_without_training_settings_during_infer_process_with
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_training_settings.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,
@@ -906,7 +918,7 @@ def test_valid_metadata_file_without_training_settings_during_infer_process_with
f"{DIR_NAME}/unit/validation_schema/fixtures/"
"metadata_file_without_training_settings.yaml"
)
- metadata = load_metadata_file(path_to_metadata)
+ metadata = MetadataLoader(path_to_metadata).load_data()
with caplog.at_level(level="DEBUG"):
ValidationSchema(
metadata=metadata,