fix the process of saving the data in ".avro" format without providing the schema (#484)

* refactor 'ml/convertor', 'ml/data_loaders', fix the bug of saving the data in the '.avro' format without providing the schema

* refactor the code in 'ml/data_loaders', 'ml/convertor'

* update unit tests in 'tests/unit/convertors'

* update 'VERSION'

* refactor the code in 'ml/data_loaders'

* refactor the code in 'ml/convertor'

* update unit tests in 'tests/unit/data_loaders'

* update 'VERSION'

---------

Co-authored-by: Hanna Imshenetska <[email protected]@EVZZAMZSA0021.epam.com>
Anna050689 and Hanna Imshenetska authored Jan 6, 2025
1 parent 47ced85 commit 4425864
Showing 5 changed files with 138 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
@@ -1 +1 @@
-0.10.1
+0.10.2
38 changes: 28 additions & 10 deletions src/syngen/ml/convertor/convertor.py
@@ -1,5 +1,6 @@
-from typing import Dict
+from typing import Dict, Tuple
 from dataclasses import dataclass
+from datetime import datetime, date

 import pandas as pd
 import numpy as np
@@ -13,9 +14,18 @@ class Convertor:
     """
     schema: Dict
     df: pd.DataFrame
+    excluded_dtypes: Tuple = (str, bytes, datetime, date)

-    @staticmethod
-    def _update_data_types(schema: Dict, df: pd.DataFrame):
+    def _check_dtype_or_nan(self, dtypes: Tuple):
+        """
+        Check if the value is of the specified data types or 'np.NaN'
+        """
+        return (
+            lambda x: isinstance(x, dtypes)
+            or (not isinstance(x, self.excluded_dtypes) and np.isnan(x))
+        )
+
+    def _update_data_types(self, schema: Dict, df: pd.DataFrame):
         """
         Update data types related to the fetched schema
         """
@@ -36,6 +46,15 @@ def _update_data_types(schema: Dict, df: pd.DataFrame):
                     f"isn\'t correct for the column - '{column}' as it's not empty"
                 )

+        if not schema.get("fields"):
+            for column in df.columns:
+                if df[column].apply(lambda x: isinstance(x, int)).all():
+                    df[column] = df[column].astype(int)
+                elif df[column].apply(self._check_dtype_or_nan(dtypes=(int, float))).all():
+                    df[column] = df[column].astype(float)
+                elif df[column].apply(self._check_dtype_or_nan(dtypes=(str, bytes))).all():
+                    df[column] = df[column].astype(pd.StringDtype())
+
     @staticmethod
     def _set_none_values_to_nan(df: pd.DataFrame):
         """
@@ -51,8 +70,7 @@ def _set_none_values_to_nan(df: pd.DataFrame):
             ]
         return df

-    @staticmethod
-    def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame:
+    def _cast_values_to_string(self, df: pd.DataFrame) -> pd.DataFrame:
         """
         Cast the values contained in columns with the data type 'object'
         to 'string'
@@ -61,7 +79,7 @@ def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame:
         for column in df_object_subset:
             df[column] = [
                 i
-                if not isinstance(i, str) and np.isnan(i)
+                if not isinstance(i, self.excluded_dtypes) and np.isnan(i)
                 else str(i)
                 for i in df[column]
             ]
@@ -73,14 +91,13 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame:
         """
         if not df.empty:
             try:
-                df = self._set_none_values_to_nan(df)
-                df = self._cast_values_to_string(df)
                 self._update_data_types(schema, df)
-                return df
             except Exception as e:
                 logger.error(e)
                 raise e
             else:
+                df = self._set_none_values_to_nan(df)
+                df = self._cast_values_to_string(df)
                 return df
         else:
             return df

@@ -114,6 +131,7 @@ def _convert_schema(schema) -> Dict:
         """
         converted_schema = dict()
         converted_schema["fields"] = dict()
+        schema = schema if schema else dict()
         for column, data_type in schema.items():
             fields = converted_schema["fields"]
             if "int" in data_type or "long" in data_type or "boolean" in data_type:
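In short: the ".avro" save path used to crash when no schema was supplied, because `_convert_schema` called `.items()` on `None`. The new guard (`schema = schema if schema else dict()`) yields an empty "fields" mapping, and `_update_data_types` then infers a dtype per column, with `_check_dtype_or_nan` treating `np.NaN` as an allowed gap in otherwise int/float or str/bytes columns. A minimal standalone sketch of that inference logic (illustrative data and names, not the library code itself):

    from datetime import datetime, date
    from typing import Tuple

    import numpy as np
    import pandas as pd

    # Types np.isnan() would choke on; they are ruled out before the NaN check
    EXCLUDED_DTYPES: Tuple = (str, bytes, datetime, date)

    def check_dtype_or_nan(dtypes: Tuple):
        # Predicate: value is an instance of `dtypes`, or a float NaN
        return lambda x: isinstance(x, dtypes) or (
            not isinstance(x, EXCLUDED_DTYPES) and np.isnan(x)
        )

    df = pd.DataFrame(
        {
            "ints": [1, 2, 3],                 # pure ints        -> int64
            "ints_with_gaps": [1, np.nan, 3],  # ints plus NaN    -> float64
            "strings": ["a", np.nan, "c"],     # strings plus NaN -> string
        },
        dtype=object,  # mimic data loaded without any type information
    )

    for column in df.columns:
        if df[column].apply(lambda x: isinstance(x, int)).all():
            df[column] = df[column].astype(int)
        elif df[column].apply(check_dtype_or_nan(dtypes=(int, float))).all():
            df[column] = df[column].astype(float)
        elif df[column].apply(check_dtype_or_nan(dtypes=(str, bytes))).all():
            df[column] = df[column].astype(pd.StringDtype())

    print(df.dtypes)  # ints: int64, ints_with_gaps: float64, strings: string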
14 changes: 7 additions & 7 deletions src/syngen/ml/data_loaders/data_loaders.py
@@ -283,7 +283,7 @@ def _load_data(self) -> pd.DataFrame:
             return pdx.from_avro(f)

     @staticmethod
-    def _get_preprocessed_schema(schema: Dict) -> Dict:
+    def _get_preprocessed_schema(schema: Optional[Dict]) -> Optional[Dict]:
         """
         Get the preprocessed schema
         """
@@ -293,6 +293,7 @@ def _get_preprocessed_schema(schema: Dict) -> Dict:
                 for field
                 in schema.get("fields", {})
             }
+        return schema

     def load_data(self, **kwargs) -> Tuple[pd.DataFrame, Dict]:
         """
@@ -301,7 +302,7 @@ def load_data(self, **kwargs) -> Tuple[pd.DataFrame, Dict]:
         try:
             df = self._load_data()
             schema = self.load_schema()
-            return self._preprocess_schema_and_df(schema, df)
+            return self._get_schema_and_df(schema, df)
         except FileNotFoundError as error:
             message = (
                 f"It seems that the path to the table isn't valid.\n"
@@ -321,10 +322,9 @@ def _save_data(self, df: pd.DataFrame, schema: Optional[Dict]):
     def save_data(self, df: pd.DataFrame, schema: Optional[Dict] = None, **kwargs):
         if schema is not None:
             logger.trace(f"The data will be saved with the schema: {schema}")
-        preprocessed_schema = (
-            self._get_preprocessed_schema(schema) if schema is not None else schema
-        )
-        df = AvroConvertor(preprocessed_schema, df).preprocessed_df
+
+        preprocessed_schema = self._get_preprocessed_schema(schema)
+        df = AvroConvertor(preprocessed_schema, df).preprocessed_df
         self._save_data(df, schema)

     def __load_original_schema(self):
@@ -351,7 +351,7 @@ def load_schema(self) -> Dict[str, str]:
         return self._get_preprocessed_schema(original_schema)

     @staticmethod
-    def _preprocess_schema_and_df(
+    def _get_schema_and_df(
             schema: Dict[str, str], df: pd.DataFrame
     ) -> Tuple[pd.DataFrame, Dict[str, str]]:
         """
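Loader-side, the notable changes: `_get_preprocessed_schema` gains an explicit `return schema` (previously the function could fall off the end and hand callers `None`), its signature is widened to `Optional[Dict]`, and `save_data` now routes every schema, including `None`, through the same preprocessing before handing it to `AvroConvertor`. A hedged round-trip sketch of the fixed behavior, modeled on the new unit test below (the file name is illustrative, and the import assumes `DataLoader` is exported from `syngen.ml.data_loaders`):

    import pandas as pd
    from syngen.ml.data_loaders import DataLoader

    df = pd.DataFrame({"id": [1, 2], "height": [1.70, 1.82], "gender": [0, 1]})

    # No schema provided: the Avro field types are inferred from the dataframe
    DataLoader("table.avro").save_data(df, schema=None)

    loaded_df, schema = DataLoader("table.avro").load_data()
    # Expected, per the new test:
    # schema == {"fields": {"gender": "int", "height": "float", "id": "int"},
    #            "format": "Avro"}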
83 changes: 83 additions & 0 deletions src/tests/unit/convertors/test_convertors.py
@@ -167,6 +167,89 @@ def test_initiate_avro_convertor(rp_logger):
     rp_logger.info(SUCCESSFUL_MESSAGE)


+def test_initiate_avro_convertor_without_provided_schema(rp_logger):
+    rp_logger.info("Initiating the instance of the class AvroConvertor without a provided schema")
+    df, schema = DataLoader(
+        f"{DIR_NAME}/unit/convertors/fixtures/csv_tables/table_with_diff_data_types.csv"
+    ).load_data()
+
+    convertor = AvroConvertor(schema=None, df=df)
+
+    assert df.dtypes.to_dict() == {
+        "employeekey": dtype("int64"),
+        "parentemployeekey": dtype("float64"),
+        "parentemployeenationalidalternatekey": dtype("float64"),
+        "employeenationalidalternatekey": pd.StringDtype(),
+        "salesterritorykey": dtype("int64"),
+        "firstname": pd.StringDtype(),
+        "lastname": pd.StringDtype(),
+        "middlename": pd.StringDtype(),
+        "namestyle": dtype("int64"),
+        "title": pd.StringDtype(),
+        "hiredate": pd.StringDtype(),
+        "birthdate": pd.StringDtype(),
+        "loginid": pd.StringDtype(),
+        "emailaddress": pd.StringDtype(),
+        "phone": pd.StringDtype(),
+        "maritalstatus": pd.StringDtype(),
+        "emergencycontactname": pd.StringDtype(),
+        "emergencycontactphone": pd.StringDtype(),
+        "salariedflag": dtype("int64"),
+        "gender": pd.StringDtype(),
+        "payfrequency": dtype("int64"),
+        "baserate": dtype("float64"),
+        "vacationhours": dtype("int64"),
+        "sickleavehours": dtype("int64"),
+        "currentflag": dtype("int64"),
+        "salespersonflag": dtype("int64"),
+        "departmentname": pd.StringDtype(),
+        "startdate": pd.StringDtype(),
+        "enddate": pd.StringDtype(),
+        "status": pd.StringDtype(),
+        "employeephoto": pd.StringDtype(),
+    }
+
+    assert convertor.converted_schema == {
+        "fields": {},
+        "format": "Avro"
+    }
+    assert convertor.preprocessed_df.dtypes.to_dict() == {
+        "employeekey": dtype("int64"),
+        "parentemployeekey": dtype("float64"),
+        "parentemployeenationalidalternatekey": dtype("float64"),
+        "employeenationalidalternatekey": pd.StringDtype(),
+        "salesterritorykey": dtype("int64"),
+        "firstname": pd.StringDtype(),
+        "lastname": pd.StringDtype(),
+        "middlename": pd.StringDtype(),
+        "namestyle": dtype("int64"),
+        "title": pd.StringDtype(),
+        "hiredate": pd.StringDtype(),
+        "birthdate": pd.StringDtype(),
+        "loginid": pd.StringDtype(),
+        "emailaddress": pd.StringDtype(),
+        "phone": pd.StringDtype(),
+        "maritalstatus": pd.StringDtype(),
+        "emergencycontactname": pd.StringDtype(),
+        "emergencycontactphone": pd.StringDtype(),
+        "salariedflag": dtype("int64"),
+        "gender": pd.StringDtype(),
+        "payfrequency": dtype("int64"),
+        "baserate": dtype("float64"),
+        "vacationhours": dtype("int64"),
+        "sickleavehours": dtype("int64"),
+        "currentflag": dtype("int64"),
+        "salespersonflag": dtype("int64"),
+        "departmentname": pd.StringDtype(),
+        "startdate": pd.StringDtype(),
+        "enddate": pd.StringDtype(),
+        "status": pd.StringDtype(),
+        "employeephoto": pd.StringDtype(),
+    }
+    pd.testing.assert_series_equal(convertor.preprocessed_df.dtypes, df.dtypes)
+    rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
 def test_initiate_avro_convertor_if_schema_contains_unsupported_data_type(caplog, rp_logger):
     rp_logger.info(
         "Initiating the instance of the class AvroConvertor "
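A note on the expected dtypes in the test above: NaN is itself a float, so an integer column with missing values (e.g. `parentemployeekey`) can only be inferred as float64, and date-like columns such as `hiredate` remain strings because no schema is available to map them to datetimes. A one-line illustration with hypothetical values:

    import numpy as np
    import pandas as pd

    print(pd.Series([10, np.nan, 30]).dtype)  # float64 -- NaN forces the float dtype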
19 changes: 19 additions & 0 deletions src/tests/unit/data_loaders/test_data_loaders.py
@@ -419,6 +419,25 @@ def test_save_data_in_avro_format(test_avro_path, test_df, test_avro_schema, rp_
     rp_logger.info(SUCCESSFUL_MESSAGE)


+def test_save_data_in_avro_format_without_provided_schema(
+    test_avro_path, test_df, rp_logger
+):
+    rp_logger.info("Saving data in avro format locally without a provided schema")
+    data_loader = DataLoader(test_avro_path)
+    data_loader.save_data(test_df, schema=None)
+
+    assert isinstance(data_loader.file_loader, AvroLoader)
+    assert os.path.exists(test_avro_path) is True
+
+    loaded_df, schema = data_loader.load_data()
+    pd.testing.assert_frame_equal(loaded_df, test_df)
+    assert schema == {
+        "fields": {"gender": "int", "height": "float", "id": "int"},
+        "format": "Avro",
+    }
+    rp_logger.info(SUCCESSFUL_MESSAGE)
+
+
 def test_load_data_from_table_in_pickle_format(rp_logger):
     rp_logger.info("Loading data from local table in pickle format")
     data_loader = DataLoader(f"{DIR_NAME}/unit/data_loaders/fixtures/"
