diff --git a/bhepop2/enrichment/base.py b/bhepop2/enrichment/base.py
index b1092dd..948fdc0 100644
--- a/bhepop2/enrichment/base.py
+++ b/bhepop2/enrichment/base.py
@@ -7,6 +7,8 @@
 from abc import ABC, abstractmethod
 import pandas as pd
 from numpy import random
+from bhepop2.sources.base import EnrichmentSource
+from bhepop2.utils import PopulationValidationError, SourceValidationError
 
 
 class SyntheticPopulationEnrichment(ABC, Bhepop2Logger):
@@ -22,6 +24,8 @@ class SyntheticPopulationEnrichment(ABC, Bhepop2Logger):
     are core to the SyntheticPopulationEnrichment classes.
     """
 
+    _required_source_class = EnrichmentSource
+
     def __init__(
         self,
         population: pd.DataFrame,
@@ -33,9 +37,16 @@ def __init__(
         Bhepop2Logger.__init__(self)
 
         # original population DataFrame, to be enriched
+        if population.empty:  # pragma: no cover
+            raise PopulationValidationError("Population to enrich is empty")
         self.population: pd.DataFrame = population.copy()
 
         # enrichment data source
+        if not isinstance(source, self._required_source_class):
+            raise SourceValidationError(
+                f"{self.__class__.__name__} requires an instance "
+                f"of {self._required_source_class} as a source"
+            )
         self.source = source
 
         # name of the added column containing the new feature values
@@ -51,6 +62,7 @@ def __init__(
         # input validation
         self.log("Input data validation and preprocessing", lg.INFO)
         self._validate_and_process_inputs()
+        self.source.usable_with_population(self.population)
 
         # feature assignment
 
@@ -97,7 +109,6 @@ def _get_value_for_feature(self, feature_id):
 
     # validation and read
 
-    @abstractmethod
     def _validate_and_process_inputs(self):
         """
         Validate and process the provided enrichment inputs.
diff --git a/bhepop2/enrichment/bhepop2.py b/bhepop2/enrichment/bhepop2.py
index 05c9543..011fe6a 100644
--- a/bhepop2/enrichment/bhepop2.py
+++ b/bhepop2/enrichment/bhepop2.py
@@ -100,25 +100,6 @@ def modalities(self):
         """
         return self.source.modalities
 
-    def _validate_and_process_inputs(self):
-        """
-        Validate the provided inputs and set the relevant fields.
-
-        Since Bhepop2 uses marginal distributions to enrich the population,
-        we ensure that:
-
-        * the selected attributes are present in the population
-        * the population attributes take values in the modalities corresponding to this attribute
-        """
-
-        assert isinstance(
-            self.source, MarginalDistributions
-        ), "Bhepop2Enrichment needs a MarginalDistributions source"
-
-        self.log("Setup population data")
-
-        functions.validate_population(self.population, self.modalities)
-
     def _evaluate_feature_on_population(self):
         """
         Assign feature values to the population individuals using the algorithm results.
@@ -134,6 +115,7 @@ def _evaluate_feature_on_population(self):
         res = self._get_feature_probs()
 
         # associate each individual to a crossed modality
+        # TODO : fails if "index" column already exists
         self.crossed_modalities_frequencies["index"] = self.crossed_modalities_frequencies.index
         merge = self.population.merge(
             self.crossed_modalities_frequencies,
@@ -356,7 +338,7 @@ def f0(x):
         # add a feature for all modalities except one for all variables
         for attribute in attributes:
             for modality in self.modalities[attribute][:-1]:
-                features.append(functions.modality_feature(attribute, modality))
+                features.append(modality_feature(attribute, modality))
 
         nb_lines = len(features)
         nb_cols = len(samplespace_reducted)
@@ -369,3 +351,19 @@ def f0(x):
                 crossed_modalities_matrix[i, j] = f_i_x
 
     return crossed_modalities_matrix
+
+
+def modality_feature(attribute, modality) -> callable:
+    """
+    Create a function that checks if a sample belongs to the given attribute and modality.
+
+    :param attribute: attribute value
+    :param modality: modality value
+
+    :return: feature checking function
+    """
+
+    def feature(x):
+        return x[attribute] == modality
+
+    return feature
diff --git a/bhepop2/enrichment/uniform.py b/bhepop2/enrichment/uniform.py
index c6a90a0..eb1be72 100644
--- a/bhepop2/enrichment/uniform.py
+++ b/bhepop2/enrichment/uniform.py
@@ -7,7 +7,6 @@
 """
 
 from .base import SyntheticPopulationEnrichment
-from bhepop2.sources.global_distribution import QuantitativeGlobalDistribution
 
 
 class SimpleUniformEnrichment(SyntheticPopulationEnrichment):
@@ -41,6 +40,3 @@ def _evaluate_feature_on_population(self):
     def _draw_feature_value(self):
         feature_index = self.rng.integers(len(self.source.feature_values))
         return self._get_value_for_feature(feature_index)
-
-    def _validate_and_process_inputs(self):
-        assert isinstance(self.source, QuantitativeGlobalDistribution)
diff --git a/bhepop2/functions.py b/bhepop2/functions.py
index dc53285..7e59cba 100644
--- a/bhepop2/functions.py
+++ b/bhepop2/functions.py
@@ -17,54 +17,9 @@ def get_attributes(modalities: dict) -> list:
     return list(modalities.keys())
 
 
-def modality_feature(attribute, modality) -> callable:
-    """
-    Create a function that checks if a sample belongs to the given attribute and modality.
-
-    :param attribute: attribute value
-    :param modality: modality value
-
-    :return: feature checking function
-    """
-
-    def feature(x):
-        return x[attribute] == modality
-
-    return feature
-
-
 # distribution functions
 
 
-def validate_distributions(distributions: pd.DataFrame, attribute_selection, mode):
-    """
-    Validate the format and contents of the given distribution.
-
-    :param distributions: distribution DataFrame
-    :param attribute_selection: list of attributes to keep in the distribution, or None
-    :param mode: "qualitative" or "quantitative"
-    :raises: AssertionError
-    """
-
-    assert not distributions.empty, "Empty distributions table provided"
-
-    if mode == "quantitative":
-        # we could validate the distributions columns (positive, monotony ?)
- assert {*["D{}".format(i) for i in range(1, 10)], "attribute", "modality"} <= set( - distributions.columns - ), "Distributions table lacks the required columns" - elif mode == "qualitative": - assert "attribute" in distributions.columns and "modality" in distributions.columns - else: - raise ValueError(f"Unknown mode '{mode}'") - - if attribute_selection is not None: - # check that the distributions contain the selected attributes - assert set(attribute_selection + ["all"]) <= set( - distributions["attribute"] - ), "Distributions table does not include selected attributes" - - def filter_distributions_and_infer_modalities(distributions: pd.DataFrame, attribute_selection): """ Filter distributions table with attribute selection and infer modalities. @@ -164,8 +119,6 @@ def get_feature_from_qualitative_distribution(distribution: pd.DataFrame): features.remove("attribute") features.remove("modality") - assert (distribution[features].apply(lambda row: np.isclose(row.sum(), 1), axis=1)).all() - return features @@ -226,28 +179,6 @@ def interpolate_feature_prob(feature_value: float, distribution: list): # population functions -def validate_population(population: pd.DataFrame, modalities: dict): - """ - Validate the format and contents of the given population. - - Check that the population is compatible with the chosen modalities. - - :param population: distribution DataFrame - :param modalities: - :raises: AssertionError - """ - - attributes = get_attributes(modalities) - - assert {*attributes} <= set(population.columns) - - for attribute in attributes: - assert population[attribute].isin(modalities[attribute]).all(), ( - f"Population validation: one of the modality values was not " - f"found in distributions for the attribute '{attribute}'" - ) - - def compute_crossed_modalities_frequencies( population: pd.DataFrame, modalities: dict ) -> pd.DataFrame: diff --git a/bhepop2/optim.py b/bhepop2/optim.py index 427b51c..93279bd 100644 --- a/bhepop2/optim.py +++ b/bhepop2/optim.py @@ -2,6 +2,7 @@ from numpy.linalg import norm from .utils import log import logging as lg +import warnings def minxent_gradient( @@ -48,7 +49,24 @@ def minxent_gradient( while not (did_ascent and did_descent): lambda_new = lambda_old - alpha * f_old - lambda0 = np.log(q.T.dot(np.exp(-matrix.T.dot(lambda_new)))) + # exp can sometime exceed float64 max size + # for now, we just catch the warning and + with warnings.catch_warnings(record=True) as w: + lambda0 = np.log(q.T.dot(np.exp(-matrix.T.dot(lambda_new)))) + if len(w) > 0: + # in python 3.11 : catch_warnings(category=RuntimeWarning) + if issubclass(w[0].category, RuntimeWarning): + log( + "Leaving gradient descent due to exp exceeding float64 max size", + lg.DEBUG, + ) + break + else: + log( + f"This warning was caught during gradient descent: {w[0].category.__name__}('{w[0].message}')", + lg.WARN, + ) + level_objective_new = lambda0 + np.sum(lambda_new * eta) # level_objective_new = objective(q=q, G=G, eta=eta, lambda_=lambda_new) @@ -85,3 +103,43 @@ def minxent_gradient( test_pierreolivier = matrix.dot(pk) - eta # log("test PO : " + str(test_pierreolivier/eta), 10) return pk.tolist(), lambda_ + + +# while not (did_ascent and did_descent): +# log( +# "not did ascent and did descent", +# lg.DEBUG, +# ) +# lambda_new = lambda_old - alpha * f_old +# # exp can sometime exceed float64 max size +# # for now, we just catch the warning and +# +# converged = False +# while not converged: +# with warnings.catch_warnings(record=True) as w: +# +# lambda0 = 
np.log(q.T.dot(np.exp(-matrix.T.dot(lambda_new)))) +# +# if len(w) > 0: +# log("", lg.DEBUG) +# log( +# lambda_new, +# lg.DEBUG, +# ) +# log( +# lambda0, +# lg.DEBUG, +# ) +# # in python 3.11 : catch_warnings(category=RuntimeWarning) +# if issubclass(w[0].category, RuntimeWarning): +# alpha *= common_ratio_descending +# lambda_new = lambda_old - alpha * f_old +# +# else: +# log( +# f"This warning was caught during gradient descent: {w[0].category.__name__}('{w[0].message}')", +# lg.WARN, +# ) +# exit(0) +# else: +# converged = True diff --git a/bhepop2/sources/base.py b/bhepop2/sources/base.py index ae0cb18..98a3d62 100644 --- a/bhepop2/sources/base.py +++ b/bhepop2/sources/base.py @@ -68,10 +68,20 @@ def _validate_data(self): Raise a ValueError if data is invalid. - :raises: ValueError + :raises: SourceValidationError """ pass + def usable_with_population(self, population): + """ + Validate that this source is compatible with the given population. + + Raise a PopulationValidationError if an incompatibility is found. + + :param population: population DataFrame + :raises PopulationValidationError: + """ + @abstractmethod def get_value_for_feature(self, feature_index, rng): """ diff --git a/bhepop2/sources/global_distribution.py b/bhepop2/sources/global_distribution.py index 17792af..4462044 100644 --- a/bhepop2/sources/global_distribution.py +++ b/bhepop2/sources/global_distribution.py @@ -6,6 +6,7 @@ """ from .base import EnrichmentSource, QuantitativeAttributes +from bhepop2.utils import SourceValidationError import numpy as np @@ -41,8 +42,10 @@ def _validate_data(self): """ Check that the deciles columns are present and that length is 1. """ - assert set(self.data.columns) >= {f"D{i}" for i in range(1, 10)} - assert len(self.data) == 1 + if not set(self.data.columns) >= {f"D{i}" for i in range(1, 10)}: + raise SourceValidationError("Distribution table lacks the required columns") + if len(self.data) != 1: + raise SourceValidationError("Distribution table is expected to have exactly one row") def get_value_for_feature(self, feature_index, rng): """ diff --git a/bhepop2/sources/marginal_distributions.py b/bhepop2/sources/marginal_distributions.py index 1301bc4..f0db425 100644 --- a/bhepop2/sources/marginal_distributions.py +++ b/bhepop2/sources/marginal_distributions.py @@ -7,9 +7,11 @@ from .base import EnrichmentSource, QuantitativeAttributes from bhepop2 import functions +from bhepop2.utils import PopulationValidationError, SourceValidationError from bhepop2.analysis import QuantitativeAnalysis, QualitativeAnalysis import pandas as pd +import numpy as np from abc import abstractmethod #: attribute and modality corresponding to the global distribution @@ -67,6 +69,25 @@ def __init__(self, data, name=None, attribute_selection: list = None): super().__init__(data, name=name) def _validate_data(self): + # check "attribute" and "modality" columns existence + if not ("attribute" in self.data.columns and "modality" in self.data.columns): + raise SourceValidationError("Missing 'attribute' or 'modality' column") + + # check that the ALL_LABEL attribute is in the columns + if ALL_LABEL not in list(self.data["attribute"]): + raise SourceValidationError( + f"Missing required '{ALL_LABEL}' attribute, " + f"used to describe the global population" + ) + + # check that provided attribute selection exists in distributions + if self.attribute_selection is not None: + if not set(self.attribute_selection) <= set(self.data["attribute"]): + raise SourceValidationError( + f"Source distributions table does 
not " + f"include selected attributes {self.attribute_selection}" + ) + # quantitative or qualitative check self._validate_data_type() @@ -76,10 +97,34 @@ def _validate_data(self): self.data, self.attribute_selection ) - # check that there are modalities at the end - assert ( - len(self.modalities.keys()) > 0 - ), "No attributes found in distributions for enriching population" + def usable_with_population(self, population): + """ + Check that the population attributes are compatible with the source. + + Check that the source attributes are present in the population. + Check that the population values of each attribute are in the source distributions. + + :param population: population DataFrame + :raises: PopulationValidationError + """ + + attributes = list(self.modalities.keys()) + + if not {*attributes} <= set(population.columns): + raise PopulationValidationError( + "Some of the source attributes are missing from the population columns.\n\n" + f"Source attributes: {attributes}\n" + f"Population columns: {population.columns}" + ) + + for attribute in attributes: + if not population[attribute].isin(self.modalities[attribute]).all(): + raise PopulationValidationError( + f"Population validation: one of the values " + f"for the '{attribute}' attribute was not found in source distributions.\n" + f"Population values: {population[attribute].unique()}\n" + f"Source values: {self.modalities[attribute]}" + ) @abstractmethod def _validate_data_type(self): @@ -183,8 +228,9 @@ def _evaluate_feature_values(self): return functions.get_feature_from_qualitative_distribution(self.data) def _validate_data_type(self): - # TODO : test that self._abs_minimum is inferior to all distribution values - functions.validate_distributions(self.data, self.attribute_selection, "qualitative") + features = functions.get_feature_from_qualitative_distribution(self.data) + if not (self.data[features].apply(lambda row: np.isclose(row.sum(), 1), axis=1)).all(): + raise SourceValidationError("Some distributions don't sum to 1") def compute_feature_prob(self, attribute=ALL_LABEL, modality=ALL_LABEL): # get distribution for the given modality @@ -207,7 +253,7 @@ def compare_with_populations(self, populations, feature_name, **kwargs): feature_column=feature_name, distributions=self.data, distributions_name=self.name, - **kwargs + **kwargs, ) @@ -299,7 +345,14 @@ def _evaluate_feature_values(self): return functions.compute_feature_values(self.data, self._relative_maximum, self._delta_min) def _validate_data_type(self): - functions.validate_distributions(self.data, self.attribute_selection, "quantitative") + # TODO : test that self._abs_minimum is inferior to all distribution values + required_columns = ["attribute", "modality"] + ["D{}".format(i) for i in range(1, 10)] + if not {*required_columns} <= set(self.data.columns): + raise SourceValidationError( + f"Distributions table lacks the required columns: {required_columns}" + ) + + # we could validate the distributions columns (positive, monotony ?) 
     def compute_feature_prob(self, attribute=ALL_LABEL, modality=ALL_LABEL):
         # get distribution for the given modality
@@ -334,6 +387,7 @@ def get_value_for_feature(self, feature_index, rng):
         :param rng:
         :return:
         """
+        self.log(feature_index)
 
         interval_values = [self._abs_minimum] + self.feature_values
 
@@ -342,7 +396,7 @@ def get_value_for_feature(self, feature_index, rng):
 
         draw = rng.uniform()
         drawn_feature_value = lower + (upper - lower) * draw
-
+        self.log(drawn_feature_value)
         return drawn_feature_value
 
     def compare_with_populations(self, populations, feature_name, **kwargs):
@@ -352,5 +406,5 @@ def compare_with_populations(self, populations, feature_name, **kwargs):
             feature_column=feature_name,
             distributions=self.data,
             distributions_name=self.name,
-            **kwargs
+            **kwargs,
         )
diff --git a/bhepop2/tools.py b/bhepop2/tools.py
index a8f3d2c..9680955 100644
--- a/bhepop2/tools.py
+++ b/bhepop2/tools.py
@@ -173,7 +173,7 @@ def add_household_type_attribute(
     assert combined[household_id].is_unique
 
     population = population.merge(combined, how="left", on=household_id)
-    population[column_name].fillna("complex_hh", inplace=True)
+    population[column_name] = population[column_name].fillna("complex_hh")
 
     return population
 
@@ -257,7 +257,7 @@ def read_filosofi(filepath: str, year: str, attributes: list, communes=None):
         sheet_list = sheet_list + [x["sheet"] for x in attribute["modalities"]]
 
     # read Filosofi excel file
-    filosofi_sheets = read_filosofi_excel(filepath, sheet_list)
+    filosofi_sheets = pd.read_excel(filepath, sheet_name=sheet_list, skiprows=5)
 
     # fetch distributions for the given attributes
    distributions = read_filosofi_attributes(filosofi_sheets, year, attributes, communes)
@@ -265,19 +265,6 @@
     return distributions
 
 
-def read_filosofi_excel(filepath: str, sheet_list: list):
-    """
-    Read list of sheets from Filosofi excel file.
-
-    :param filepath: path to Filosofi excel file (DISP_COM)
-    :param sheet_list: list of sheets to be read
-
-    :return: DataFrame indexed by sheet
-    """
-    excel_df = pd.read_excel(filepath, sheet_name=sheet_list, skiprows=5)
-    return excel_df
-
-
 def read_filosofi_attributes(filosofi_sheets, year, attributes: list, communes=None):
     """
     Read distributions from list of attributes and their modalities in filosofi sheets.
diff --git a/bhepop2/utils.py b/bhepop2/utils.py
index 7b944ff..2c33606 100644
--- a/bhepop2/utils.py
+++ b/bhepop2/utils.py
@@ -1,9 +1,25 @@
 """
-Utility functions
+Utility classes, functions and constants.
 """
 
 import logging as lg
 
+
+# bhepop2 exceptions
+
+
+class PopulationValidationError(Exception):
+    """
+    Raised when a population fails validation.
+    """
+
+
+class SourceValidationError(Exception):
+    """
+    Raised when an enrichment source fails validation.
+ """ + + # log utils (see logging library) #: logging level diff --git a/requirements.txt b/requirements.txt index b26909d..968fcaf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -numpy>=1.26.0 -pandas>=2.1.0 -xlrd>=2.0.0 -jsonschema>=4.20.0 -scipy>=1.11.0 -scikit-learn>=1.3.0 -plotly>=5.18.0 \ No newline at end of file +numpy==1.23.5 +pandas==1.5.3 +xlrd==2.0.1 +scipy==1.10.1 +scikit-learn==1.2.2 +plotly \ No newline at end of file diff --git a/tests/enrichment/test_base.py b/tests/enrichment/test_base.py index 8d02dfd..2893965 100644 --- a/tests/enrichment/test_base.py +++ b/tests/enrichment/test_base.py @@ -1,4 +1,7 @@ +import pandas as pd + from bhepop2.enrichment import Bhepop2Enrichment +from bhepop2.utils import SourceValidationError import pytest @@ -47,6 +50,13 @@ def test_init( # check validation call was made enrich_class._validate_and_process_inputs.assert_called_once_with() + def test_source_class_validation(self): + with pytest.raises(SourceValidationError): + Bhepop2Enrichment( + pd.DataFrame({"id": [1]}), + None, + ) + def test_feature_exists_error( self, synthetic_population_nantes, quantitative_marginal_distributions ): diff --git a/tests/sources/test_global_distribution.py b/tests/sources/test_global_distribution.py index 8291ef1..9627a19 100644 --- a/tests/sources/test_global_distribution.py +++ b/tests/sources/test_global_distribution.py @@ -1,5 +1,5 @@ from bhepop2.sources.global_distribution import QuantitativeGlobalDistribution - +from bhepop2.utils import SourceValidationError import pytest from numpy.random import default_rng @@ -31,11 +31,11 @@ def test_init(self, filosofi_global_distribution_nantes, test_parameters): def test_validate_data(self, filosofi_global_distribution_nantes): # DataFrame with missing decile should raise an error - with pytest.raises(AssertionError): + with pytest.raises(SourceValidationError): QuantitativeGlobalDistribution(filosofi_global_distribution_nantes.drop(["D5"], axis=1)) # empty DataFrame should raise an error - with pytest.raises(AssertionError): + with pytest.raises(SourceValidationError): QuantitativeGlobalDistribution( filosofi_global_distribution_nantes.drop( [filosofi_global_distribution_nantes.index[0]], axis=0 diff --git a/tests/sources/test_marginal_distributions.py b/tests/sources/test_marginal_distributions.py index dbca263..f97d401 100644 --- a/tests/sources/test_marginal_distributions.py +++ b/tests/sources/test_marginal_distributions.py @@ -6,6 +6,7 @@ ) from bhepop2.analysis import QuantitativeAnalysis, QualitativeAnalysis from bhepop2.functions import build_cross_table +from bhepop2.utils import PopulationValidationError from numpy.random import default_rng import pytest @@ -13,6 +14,12 @@ class TestMarginalDistributions: + @pytest.fixture(scope="class") + def source_example(self, filosofi_distributions_nantes): + return QuantitativeMarginalDistributions( + filosofi_distributions_nantes, attribute_selection=["ownership", "age"], delta_min=1000 + ) + def test_init(self, filosofi_distributions_nantes, test_modalities): attribute_selection = list(test_modalities.keys()) @@ -25,10 +32,17 @@ def test_init(self, filosofi_distributions_nantes, test_modalities): assert source.attribute_selection == attribute_selection assert source.modalities == test_modalities - def test_get_modality_distribution(self, filosofi_distributions_nantes): - source = QuantitativeMarginalDistributions(filosofi_distributions_nantes) + def test_usable_with_population(self, source_example): + with 
pytest.raises(PopulationValidationError): + source_example.usable_with_population( + pd.DataFrame({"ownership": ["Tenant"], "age": ["UNKNOWN"]}) + ) + + with pytest.raises(PopulationValidationError): + source_example.usable_with_population(pd.DataFrame({"ownership": ["Tenant"]})) - modality_distribution = source.get_modality_distribution("ownership", "Tenant") + def test_get_modality_distribution(self, source_example): + modality_distribution = source_example.get_modality_distribution("ownership", "Tenant") assert len(modality_distribution) == 1 assert modality_distribution["attribute"].iloc[0] == "ownership" diff --git a/tests/test_functions.py b/tests/test_functions.py index 078b81b..f1495ec 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -14,22 +14,6 @@ def test_get_attributes(test_modalities): assert get_attributes(test_modalities) == ["ownership", "age", "size", "family_comp"] -def test_modality_feature(): - """ - Test that the result is a function with expected behaviour. - """ - - attribute = "my_attribute" - modality_0 = "modality_0" - modality_1 = "modality_1" - - feature = modality_feature(attribute, modality_0) - - assert callable(feature) - assert feature({attribute: modality_0}) - assert not feature({attribute: modality_1}) - - def test_compute_crossed_modalities_frequencies( synthetic_population_nantes, test_modalities, test_attributes ): @@ -88,11 +72,6 @@ def test_get_feature_from_qualitative_distribution(): assert features == ["0voit", "1voit", "2voit", "3voit"] - distribution.loc[0, "2voit"] = 0.5 - - with pytest.raises(AssertionError): - get_feature_from_qualitative_distribution(distribution) - def test_interpolate_feature_prob(): """