diff --git a/src/syngen/VERSION b/src/syngen/VERSION index a7c0a5c9..137faca0 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.10.2rc3 +0.10.2rc4 diff --git a/src/syngen/ml/convertor/convertor.py b/src/syngen/ml/convertor/convertor.py index 4a3d1b0f..d89a8b79 100644 --- a/src/syngen/ml/convertor/convertor.py +++ b/src/syngen/ml/convertor/convertor.py @@ -1,6 +1,6 @@ from typing import Dict, Tuple from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, date import pandas as pd import numpy as np @@ -14,15 +14,15 @@ class Convertor: """ schema: Dict df: pd.DataFrame + excluded_dtypes: Tuple = (str, bytes, datetime, date) - @staticmethod - def _check_dtype_or_nan(dtypes: Tuple): + def _check_dtype_or_nan(self, dtypes: Tuple): """ - Check if the value is of the specified data type or 'np.NaN' + Check if the value is of the specified data types or 'np.NaN' """ return ( lambda x: isinstance(x, dtypes) - or (not isinstance(x, (str, bytes, datetime)) and np.isnan(x)) + or (not isinstance(x, self.excluded_dtypes) and np.isnan(x)) ) def _update_data_types(self, schema: Dict, df: pd.DataFrame): @@ -70,8 +70,7 @@ def _set_none_values_to_nan(df: pd.DataFrame): ] return df - @staticmethod - def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame: + def _cast_values_to_string(self, df: pd.DataFrame) -> pd.DataFrame: """ Cast the values contained in columns with the data type 'object' to 'string' @@ -80,7 +79,7 @@ def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame: for column in df_object_subset: df[column] = [ i - if not isinstance(i, (str, bytes, datetime)) and np.isnan(i) + if not isinstance(i, self.excluded_dtypes) and np.isnan(i) else str(i) for i in df[column] ] @@ -95,11 +94,10 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame: df = self._set_none_values_to_nan(df) df = self._cast_values_to_string(df) self._update_data_types(schema, df) + return df except Exception as e: logger.error(e) raise e - else: - return df else: return df diff --git a/src/tests/unit/data_loaders/test_data_loaders.py b/src/tests/unit/data_loaders/test_data_loaders.py index 98922e2f..c9c3a900 100644 --- a/src/tests/unit/data_loaders/test_data_loaders.py +++ b/src/tests/unit/data_loaders/test_data_loaders.py @@ -419,6 +419,25 @@ def test_save_data_in_avro_format(test_avro_path, test_df, test_avro_schema, rp_ rp_logger.info(SUCCESSFUL_MESSAGE) +def test_save_data_in_avro_format_without_provided_schema( + test_avro_path, test_df, rp_logger +): + rp_logger.info("Saving data in avro format locally without a provided schema") + data_loader = DataLoader(test_avro_path) + data_loader.save_data(test_df, schema=None) + + assert isinstance(data_loader.file_loader, AvroLoader) + assert os.path.exists(test_avro_path) is True + + loaded_df, schema = data_loader.load_data() + pd.testing.assert_frame_equal(loaded_df, test_df) + assert schema == { + "fields": {"gender": "int", "height": "float", "id": "int"}, + "format": "Avro", + } + rp_logger.info(SUCCESSFUL_MESSAGE) + + def test_load_data_from_table_in_pickle_format(rp_logger): rp_logger.info("Loading data from local table in pickle format") data_loader = DataLoader(f"{DIR_NAME}/unit/data_loaders/fixtures/"