Skip to content

Commit

Permalink
update unit tests in 'tests/unit/data_loaders'
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Jan 3, 2025
1 parent 07574c9 commit df10ae7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.10.2rc3
0.10.2rc4
18 changes: 8 additions & 10 deletions src/syngen/ml/convertor/convertor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Tuple
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, date

import pandas as pd
import numpy as np
Expand All @@ -14,15 +14,15 @@ class Convertor:
"""
schema: Dict
df: pd.DataFrame
excluded_dtypes: Tuple = (str, bytes, datetime, date)

@staticmethod
def _check_dtype_or_nan(dtypes: Tuple):
def _check_dtype_or_nan(self, dtypes: Tuple):
"""
Check if the value is of the specified data type or 'np.NaN'
Check if the value is of the specified data types or 'np.NaN'
"""
return (
lambda x: isinstance(x, dtypes)
or (not isinstance(x, (str, bytes, datetime)) and np.isnan(x))
or (not isinstance(x, self.excluded_dtypes) and np.isnan(x))
)

def _update_data_types(self, schema: Dict, df: pd.DataFrame):
Expand Down Expand Up @@ -70,8 +70,7 @@ def _set_none_values_to_nan(df: pd.DataFrame):
]
return df

@staticmethod
def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame:
def _cast_values_to_string(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Cast the values contained in columns with the data type 'object'
to 'string'
Expand All @@ -80,7 +79,7 @@ def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame:
for column in df_object_subset:
df[column] = [
i
if not isinstance(i, (str, bytes, datetime)) and np.isnan(i)
if not isinstance(i, self.excluded_dtypes) and np.isnan(i)
else str(i)
for i in df[column]
]
Expand All @@ -95,11 +94,10 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame:
df = self._set_none_values_to_nan(df)
df = self._cast_values_to_string(df)
self._update_data_types(schema, df)
return df
except Exception as e:
logger.error(e)
raise e
else:
return df
else:
return df

Expand Down
19 changes: 19 additions & 0 deletions src/tests/unit/data_loaders/test_data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,25 @@ def test_save_data_in_avro_format(test_avro_path, test_df, test_avro_schema, rp_
rp_logger.info(SUCCESSFUL_MESSAGE)


def test_save_data_in_avro_format_without_provided_schema(
    test_avro_path, test_df, rp_logger
):
    """
    Save a dataframe in avro format with 'schema=None', then load it back
    and compare both the loaded data and the schema returned by the loader
    """
    rp_logger.info("Saving data in avro format locally without a provided schema")
    expected_schema = {
        "fields": {"gender": "int", "height": "float", "id": "int"},
        "format": "Avro",
    }

    loader = DataLoader(test_avro_path)
    loader.save_data(test_df, schema=None)

    # The loader is expected to pick the avro file loader and create the file
    assert isinstance(loader.file_loader, AvroLoader)
    assert os.path.exists(test_avro_path) is True

    # Round trip: the loaded dataframe and schema are compared
    # with the original data and the expected schema
    restored_df, restored_schema = loader.load_data()
    pd.testing.assert_frame_equal(restored_df, test_df)
    assert restored_schema == expected_schema
    rp_logger.info(SUCCESSFUL_MESSAGE)


def test_load_data_from_table_in_pickle_format(rp_logger):
rp_logger.info("Loading data from local table in pickle format")
data_loader = DataLoader(f"{DIR_NAME}/unit/data_loaders/fixtures/"
Expand Down

0 comments on commit df10ae7

Please sign in to comment.