Skip to content

Commit

Permalink
refactor 'ml/convertor', 'ml/data_loaders', fix the bug of saving the…
Browse files Browse the repository at this point in the history
… data in the '.avro' format without providing the schema
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Jan 2, 2025
1 parent 47ced85 commit bae78e1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.10.1
0.10.2rc0
26 changes: 21 additions & 5 deletions src/syngen/ml/convertor/convertor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Tuple
from dataclasses import dataclass

import pandas as pd
Expand All @@ -15,7 +15,13 @@ class Convertor:
df: pd.DataFrame

@staticmethod
def _update_data_types(schema: Dict, df: pd.DataFrame):
def _check_dtype_or_nan(dtypes: Tuple):
return (
lambda x: isinstance(x, dtypes)
or (not isinstance(x, (str, bytes)) and np.isnan(x))
)

def _update_data_types(self, schema: Dict, df: pd.DataFrame):
"""
Update data types related to the fetched schema
"""
Expand All @@ -36,6 +42,15 @@ def _update_data_types(schema: Dict, df: pd.DataFrame):
f"isn\'t correct for the column - '{column}' as it's not empty"
)

if not schema.get("fields"):
for column in df.columns:
if df[column].apply(lambda x: isinstance(x, int)).all():
df[column] = df[column].astype(int)
elif df[column].apply(self._check_dtype_or_nan(dtypes=(int, float))).all():
df[column] = df[column].astype(float)
elif df[column].apply(self._check_dtype_or_nan(dtypes=(str, bytes))).all():
df[column] = df[column].astype(pd.StringDtype())

@staticmethod
def _set_none_values_to_nan(df: pd.DataFrame):
"""
Expand All @@ -61,7 +76,7 @@ def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame:
for column in df_object_subset:
df[column] = [
i
if not isinstance(i, str) and np.isnan(i)
if not isinstance(i, (str, bytes)) and np.isnan(i)
else str(i)
for i in df[column]
]
Expand All @@ -73,13 +88,13 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame:
"""
if not df.empty:
try:
df = self._set_none_values_to_nan(df)
df = self._cast_values_to_string(df)
self._update_data_types(schema, df)
except Exception as e:
logger.error(e)
raise e
else:
df = self._set_none_values_to_nan(df)
df = self._cast_values_to_string(df)
return df
else:
return df
Expand Down Expand Up @@ -114,6 +129,7 @@ def _convert_schema(schema) -> Dict:
"""
converted_schema = dict()
converted_schema["fields"] = dict()
schema = schema if schema else dict()
for column, data_type in schema.items():
fields = converted_schema["fields"]
if "int" in data_type or "long" in data_type or "boolean" in data_type:
Expand Down
8 changes: 4 additions & 4 deletions src/syngen/ml/data_loaders/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def _get_preprocessed_schema(schema: Dict) -> Dict:
for field
in schema.get("fields", {})
}
return schema

def load_data(self, **kwargs) -> Tuple[pd.DataFrame, Dict]:
"""
Expand Down Expand Up @@ -321,10 +322,9 @@ def _save_data(self, df: pd.DataFrame, schema: Optional[Dict]):
def save_data(self, df: pd.DataFrame, schema: Optional[Dict] = None, **kwargs):
if schema is not None:
logger.trace(f"The data will be saved with the schema: {schema}")
preprocessed_schema = (
self._get_preprocessed_schema(schema) if schema is not None else schema
)
df = AvroConvertor(preprocessed_schema, df).preprocessed_df

preprocessed_schema = self._get_preprocessed_schema(schema)
df = AvroConvertor(preprocessed_schema, df).preprocessed_df
self._save_data(df, schema)

def __load_original_schema(self):
Expand Down

0 comments on commit bae78e1

Please sign in to comment.