From bae78e16c88ba8e9d93fec54f4450c36ff0eb05b Mon Sep 17 00:00:00 2001 From: Hanna Imshenetska Date: Thu, 2 Jan 2025 16:31:34 +0000 Subject: [PATCH] refactor 'ml/convertor', 'ml/data_loaders', fix the bug of saving the data in the '.avro' format without providing the schema --- src/syngen/VERSION | 2 +- src/syngen/ml/convertor/convertor.py | 26 +++++++++++++++++----- src/syngen/ml/data_loaders/data_loaders.py | 8 +++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/syngen/VERSION b/src/syngen/VERSION index 57121573..ba8bda59 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.10.1 +0.10.2rc0 diff --git a/src/syngen/ml/convertor/convertor.py b/src/syngen/ml/convertor/convertor.py index 8f8de4ae..97a3d126 100644 --- a/src/syngen/ml/convertor/convertor.py +++ b/src/syngen/ml/convertor/convertor.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Tuple from dataclasses import dataclass import pandas as pd @@ -15,7 +15,13 @@ class Convertor: df: pd.DataFrame @staticmethod - def _update_data_types(schema: Dict, df: pd.DataFrame): + def _check_dtype_or_nan(dtypes: Tuple): + return ( + lambda x: isinstance(x, dtypes) + or (not isinstance(x, (str, bytes)) and np.isnan(x)) + ) + + def _update_data_types(self, schema: Dict, df: pd.DataFrame): """ Update data types related to the fetched schema """ @@ -36,6 +42,15 @@ def _update_data_types(schema: Dict, df: pd.DataFrame): f"isn\'t correct for the column - '{column}' as it's not empty" ) + if not schema.get("fields"): + for column in df.columns: + if df[column].apply(lambda x: isinstance(x, int)).all(): + df[column] = df[column].astype(int) + elif df[column].apply(self._check_dtype_or_nan(dtypes=(int, float))).all(): + df[column] = df[column].astype(float) + elif df[column].apply(self._check_dtype_or_nan(dtypes=(str, bytes))).all(): + df[column] = df[column].astype(pd.StringDtype()) + @staticmethod def _set_none_values_to_nan(df: pd.DataFrame): """ @@ -61,7 +76,7 @@ def _cast_values_to_string(df: pd.DataFrame) -> pd.DataFrame: for column in df_object_subset: df[column] = [ i - if not isinstance(i, str) and np.isnan(i) + if not isinstance(i, (str, bytes)) and np.isnan(i) else str(i) for i in df[column] ] @@ -73,13 +88,13 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame: """ if not df.empty: try: + df = self._set_none_values_to_nan(df) + df = self._cast_values_to_string(df) self._update_data_types(schema, df) except Exception as e: logger.error(e) raise e else: - df = self._set_none_values_to_nan(df) - df = self._cast_values_to_string(df) return df else: return df @@ -114,6 +129,7 @@ def _convert_schema(schema) -> Dict: """ converted_schema = dict() converted_schema["fields"] = dict() + schema = schema if schema else dict() for column, data_type in schema.items(): fields = converted_schema["fields"] if "int" in data_type or "long" in data_type or "boolean" in data_type: diff --git a/src/syngen/ml/data_loaders/data_loaders.py b/src/syngen/ml/data_loaders/data_loaders.py index 74f7128f..ad1d7a36 100644 --- a/src/syngen/ml/data_loaders/data_loaders.py +++ b/src/syngen/ml/data_loaders/data_loaders.py @@ -293,6 +293,7 @@ def _get_preprocessed_schema(schema: Dict) -> Dict: for field in schema.get("fields", {}) } + return schema def load_data(self, **kwargs) -> Tuple[pd.DataFrame, Dict]: """ @@ -321,10 +322,9 @@ def _save_data(self, df: pd.DataFrame, schema: Optional[Dict]): def save_data(self, df: pd.DataFrame, schema: Optional[Dict] = None, **kwargs): if schema is not None: logger.trace(f"The data will be saved with the schema: {schema}") - preprocessed_schema = ( - self._get_preprocessed_schema(schema) if schema is not None else schema - ) - df = AvroConvertor(preprocessed_schema, df).preprocessed_df + + preprocessed_schema = self._get_preprocessed_schema(schema) + df = AvroConvertor(preprocessed_schema, df).preprocessed_df self._save_data(df, schema) def __load_original_schema(self):