From 4425864f44aa33907d665e71a57707051510d18c Mon Sep 17 00:00:00 2001 From: Hanna Imshenetska <69158110+Anna050689@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:09:27 +0200 Subject: [PATCH] fix the process of saving the data in ".avro" format without providing the schema (#484) * refactor 'ml/convertor', 'ml/data_loaders', fix the bug of saving the data in the '.avro' format without providing the schema * refactor the code in 'ml/data_loaders', 'ml/convertor' * update unit tests in 'tests/unit/convertors' * update 'VERSION' * refactor the code in 'ml/data_loaders' * refactor the code in 'ml/convertor' * update unit tests in 'tests/unit/data_loaders' * update 'VERSION' --------- Co-authored-by: Hanna Imshenetska --- src/syngen/VERSION | 2 +- src/syngen/ml/convertor/convertor.py | 38 ++++++--- src/syngen/ml/data_loaders/data_loaders.py | 14 ++-- src/tests/unit/convertors/test_convertors.py | 83 +++++++++++++++++++ .../unit/data_loaders/test_data_loaders.py | 19 +++++ 5 files changed, 138 insertions(+), 18 deletions(-) diff --git a/src/syngen/VERSION b/src/syngen/VERSION index 57121573..5eef0f10 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.10.1 +0.10.2 diff --git a/src/syngen/ml/convertor/convertor.py b/src/syngen/ml/convertor/convertor.py index 8f8de4ae..d89a8b79 100644 --- a/src/syngen/ml/convertor/convertor.py +++ b/src/syngen/ml/convertor/convertor.py @@ -1,5 +1,6 @@ -from typing import Dict +from typing import Dict, Tuple from dataclasses import dataclass +from datetime import datetime, date import pandas as pd import numpy as np @@ -13,9 +14,18 @@ class Convertor: """ schema: Dict df: pd.DataFrame + excluded_dtypes: Tuple = (str, bytes, datetime, date) - @staticmethod - def _update_data_types(schema: Dict, df: pd.DataFrame): + def _check_dtype_or_nan(self, dtypes: Tuple): + """ + Check if the value is of the specified data types or 'np.NaN' + """ + return ( + lambda x: isinstance(x, dtypes) + or (not isinstance(x, self.excluded_dtypes) and np.isnan(x)) + ) + + def _update_data_types(self, schema: Dict, df: pd.DataFrame): """ Update data types related to the fetched schema """ @@ -36,6 +46,15 @@ def _update_data_types(schema: Dict, df: pd.DataFrame): f"isn\'t correct for the column - '{column}' as it's not empty" ) + if not schema.get("fields"): + for column in df.columns: + if df[column].apply(lambda x: isinstance(x, int)).all(): + df[column] = df[column].astype(int) + elif df[column].apply(self._check_dtype_or_nan(dtypes=(int, float))).all(): + df[column] = df[column].astype(float) + elif df[column].apply(self._check_dtype_or_nan(dtypes=(str, bytes))).all(): + df[column] = df[column].astype(pd.StringDtype()) + @staticmethod def _set_none_values_to_nan(df: pd.DataFrame): """ @@ -51,8 +70,7 @@ def _set_none_values_to_nan(df: pd.DataFrame): ] return df - @staticmethod - def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame: + def _cast_values_to_string(self, df: pd.DataFrame) -> pd.DataFrame: """ Cast the values contained in columns with the data type 'object' to 'string' @@ -61,7 +79,7 @@ def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame: for column in df_object_subset: df[column] = [ i - if not isinstance(i, str) and np.isnan(i) + if not isinstance(i, self.excluded_dtypes) and np.isnan(i) else str(i) for i in df[column] ] @@ -73,14 +91,13 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame: """ if not df.empty: try: + df = self._set_none_values_to_nan(df) + df = self._cast_values_to_string(df) self._update_data_types(schema, df) + return df except Exception as e: logger.error(e) raise e - else: - df = self._set_none_values_to_nan(df) - df = self._cast_values_to_string(df) - return df else: return df @@ -114,6 +131,7 @@ def _convert_schema(schema) -> Dict: """ converted_schema = dict() converted_schema["fields"] = dict() + schema = schema if schema else dict() for column, data_type in schema.items(): fields = converted_schema["fields"] if "int" in data_type or "long" in data_type or "boolean" in data_type: diff --git a/src/syngen/ml/data_loaders/data_loaders.py b/src/syngen/ml/data_loaders/data_loaders.py index 74f7128f..50479fdf 100644 --- a/src/syngen/ml/data_loaders/data_loaders.py +++ b/src/syngen/ml/data_loaders/data_loaders.py @@ -283,7 +283,7 @@ def _load_data(self) -> pd.DataFrame: return pdx.from_avro(f) @staticmethod - def _get_preprocessed_schema(schema: Dict) -> Dict: + def _get_preprocessed_schema(schema: Optional[Dict]) -> Optional[Dict]: """ Get the preprocessed schema """ @@ -293,6 +293,7 @@ def _get_preprocessed_schema(schema: Dict) -> Dict: for field in schema.get("fields", {}) } + return schema def load_data(self, **kwargs) -> Tuple[pd.DataFrame, Dict]: """ @@ -301,7 +302,7 @@ def load_data(self, **kwargs) -> Tuple[pd.DataFrame, Dict]: try: df = self._load_data() schema = self.load_schema() - return self._preprocess_schema_and_df(schema, df) + return self._get_schema_and_df(schema, df) except FileNotFoundError as error: message = ( f"It seems that the path to the table isn't valid.\n" @@ -321,10 +322,9 @@ def _save_data(self, df: pd.DataFrame, schema: Optional[Dict]): def save_data(self, df: pd.DataFrame, schema: Optional[Dict] = None, **kwargs): if schema is not None: logger.trace(f"The data will be saved with the schema: {schema}") - preprocessed_schema = ( - self._get_preprocessed_schema(schema) if schema is not None else schema - ) - df = AvroConvertor(preprocessed_schema, df).preprocessed_df + + preprocessed_schema = self._get_preprocessed_schema(schema) + df = AvroConvertor(preprocessed_schema, df).preprocessed_df self._save_data(df, schema) def __load_original_schema(self): @@ -351,7 +351,7 @@ def load_schema(self) -> Dict[str, str]: return self._get_preprocessed_schema(original_schema) @staticmethod - def _preprocess_schema_and_df( + def _get_schema_and_df( schema: Dict[str, str], df: pd.DataFrame ) -> Tuple[pd.DataFrame, Dict[str, str]]: """ diff --git a/src/tests/unit/convertors/test_convertors.py b/src/tests/unit/convertors/test_convertors.py index 7f73a81e..557be0dc 100644 --- a/src/tests/unit/convertors/test_convertors.py +++ b/src/tests/unit/convertors/test_convertors.py @@ -167,6 +167,89 @@ def test_initiate_avro_convertor(rp_logger): rp_logger.info(SUCCESSFUL_MESSAGE) +def test_initiate_avro_convertor_without_provided_schema(rp_logger): + rp_logger.info("Initiating the instance of the class AvroConvertor without a provided schema") + df, schema = DataLoader( + f"{DIR_NAME}/unit/convertors/fixtures/csv_tables/table_with_diff_data_types.csv" + ).load_data() + + convertor = AvroConvertor(schema=None, df=df) + + assert df.dtypes.to_dict() == { + "employeekey": dtype("int64"), + "parentemployeekey": dtype("float64"), + "parentemployeenationalidalternatekey": dtype("float64"), + "employeenationalidalternatekey": pd.StringDtype(), + "salesterritorykey": dtype("int64"), + "firstname": pd.StringDtype(), + "lastname": pd.StringDtype(), + "middlename": pd.StringDtype(), + "namestyle": dtype("int64"), + "title": pd.StringDtype(), + "hiredate": pd.StringDtype(), + "birthdate": pd.StringDtype(), + "loginid": pd.StringDtype(), + "emailaddress": pd.StringDtype(), + "phone": pd.StringDtype(), + "maritalstatus": pd.StringDtype(), + "emergencycontactname": pd.StringDtype(), + "emergencycontactphone": pd.StringDtype(), + "salariedflag": dtype("int64"), + "gender": pd.StringDtype(), + "payfrequency": dtype("int64"), + "baserate": dtype("float64"), + "vacationhours": dtype("int64"), + "sickleavehours": dtype("int64"), + "currentflag": dtype("int64"), + "salespersonflag": dtype("int64"), + "departmentname": pd.StringDtype(), + "startdate": pd.StringDtype(), + "enddate": pd.StringDtype(), + "status": pd.StringDtype(), + "employeephoto": pd.StringDtype(), + } + + assert convertor.converted_schema == { + "fields": {}, + "format": "Avro" + } + assert convertor.preprocessed_df.dtypes.to_dict() == { + "employeekey": dtype("int64"), + "parentemployeekey": dtype("float64"), + "parentemployeenationalidalternatekey": dtype("float64"), + "employeenationalidalternatekey": pd.StringDtype(), + "salesterritorykey": dtype("int64"), + "firstname": pd.StringDtype(), + "lastname": pd.StringDtype(), + "middlename": pd.StringDtype(), + "namestyle": dtype("int64"), + "title": pd.StringDtype(), + "hiredate": pd.StringDtype(), + "birthdate": pd.StringDtype(), + "loginid": pd.StringDtype(), + "emailaddress": pd.StringDtype(), + "phone": pd.StringDtype(), + "maritalstatus": pd.StringDtype(), + "emergencycontactname": pd.StringDtype(), + "emergencycontactphone": pd.StringDtype(), + "salariedflag": dtype("int64"), + "gender": pd.StringDtype(), + "payfrequency": dtype("int64"), + "baserate": dtype("float64"), + "vacationhours": dtype("int64"), + "sickleavehours": dtype("int64"), + "currentflag": dtype("int64"), + "salespersonflag": dtype("int64"), + "departmentname": pd.StringDtype(), + "startdate": pd.StringDtype(), + "enddate": pd.StringDtype(), + "status": pd.StringDtype(), + "employeephoto": pd.StringDtype(), + } + pd.testing.assert_series_equal(convertor.preprocessed_df.dtypes, df.dtypes) + rp_logger.info(SUCCESSFUL_MESSAGE) + + def test_initiate_avro_convertor_if_schema_contains_unsupported_data_type(caplog, rp_logger): rp_logger.info( "Initiating the instance of the class AvroConvertor " diff --git a/src/tests/unit/data_loaders/test_data_loaders.py b/src/tests/unit/data_loaders/test_data_loaders.py index 98922e2f..c9c3a900 100644 --- a/src/tests/unit/data_loaders/test_data_loaders.py +++ b/src/tests/unit/data_loaders/test_data_loaders.py @@ -419,6 +419,25 @@ def test_save_data_in_avro_format(test_avro_path, test_df, test_avro_schema, rp_ rp_logger.info(SUCCESSFUL_MESSAGE) +def test_save_data_in_avro_format_without_provided_schema( + test_avro_path, test_df, rp_logger +): + rp_logger.info("Saving data in avro format locally without a provided schema") + data_loader = DataLoader(test_avro_path) + data_loader.save_data(test_df, schema=None) + + assert isinstance(data_loader.file_loader, AvroLoader) + assert os.path.exists(test_avro_path) is True + + loaded_df, schema = data_loader.load_data() + pd.testing.assert_frame_equal(loaded_df, test_df) + assert schema == { + "fields": {"gender": "int", "height": "float", "id": "int"}, + "format": "Avro", + } + rp_logger.info(SUCCESSFUL_MESSAGE) + + def test_load_data_from_table_in_pickle_format(rp_logger): rp_logger.info("Loading data from local table in pickle format") data_loader = DataLoader(f"{DIR_NAME}/unit/data_loaders/fixtures/"