From 80f5a437a3f9a499ce68f92942ffcac95f396364 Mon Sep 17 00:00:00 2001 From: leo-desbureaux-tellae Date: Thu, 7 Mar 2024 16:21:53 +0100 Subject: [PATCH] remove rng from EnrichmentSource and pass rng to get_value_for_feature() method --- bhepop2/enrichment/base.py | 13 +++++++++++++ bhepop2/enrichment/bhepop2.py | 2 +- bhepop2/enrichment/uniform.py | 2 +- bhepop2/sources/base.py | 7 ++----- bhepop2/sources/global_distribution.py | 7 ++++--- bhepop2/sources/marginal_distributions.py | 7 ++++--- 6 files changed, 25 insertions(+), 13 deletions(-) diff --git a/bhepop2/enrichment/base.py b/bhepop2/enrichment/base.py index 3d62daa..b8c2464 100644 --- a/bhepop2/enrichment/base.py +++ b/bhepop2/enrichment/base.py @@ -84,6 +84,19 @@ def _evaluate_feature_on_population(self): # implement feature evaluation using a dedicated algorithm raise NotImplementedError + def _get_value_for_feature(self, feature_id): + """ + Get a feature value for the given feature id. + + This method is a helper that calls self.source.get_value_for_feature + with feature id and self.rng. 
+ + :param feature_id: + + :return: feature value + """ + return self.source.get_value_for_feature(feature_id, self.rng) + # validation and read @abstractmethod diff --git a/bhepop2/enrichment/bhepop2.py b/bhepop2/enrichment/bhepop2.py index 04d2b96..05c9543 100644 --- a/bhepop2/enrichment/bhepop2.py +++ b/bhepop2/enrichment/bhepop2.py @@ -167,7 +167,7 @@ def _draw_feature_value(self, probs): feature_index = self.rng.choice(feature_indexes, p=filtered_probs) # get feature value from distribution - value = self.source.get_value_for_feature(feature_index) + value = self._get_value_for_feature(feature_index) return value diff --git a/bhepop2/enrichment/uniform.py b/bhepop2/enrichment/uniform.py index b2a464c..c6a90a0 100644 --- a/bhepop2/enrichment/uniform.py +++ b/bhepop2/enrichment/uniform.py @@ -40,7 +40,7 @@ def _evaluate_feature_on_population(self): def _draw_feature_value(self): feature_index = self.rng.integers(len(self.source.feature_values)) - return self.source.get_value_for_feature(feature_index) + return self._get_value_for_feature(feature_index) def _validate_and_process_inputs(self): assert isinstance(self.source, QuantitativeGlobalDistribution) diff --git a/bhepop2/sources/base.py b/bhepop2/sources/base.py index 04e6a6b..ae0cb18 100644 --- a/bhepop2/sources/base.py +++ b/bhepop2/sources/base.py @@ -31,10 +31,6 @@ def __init__(self, data, name: str = None): self.name: str = DEFAULT_SOURCE_NAME if name is None else name - # random number generator. - # Set by the SyntheticPopulationEnrichment class is used with one. - self.rng = None - self.data = data self._feature_values = None @@ -77,7 +73,7 @@ def _validate_data(self): pass @abstractmethod - def get_value_for_feature(self, feature_index): + def get_value_for_feature(self, feature_index, rng): """ Return a feature value for the given feature index. @@ -85,6 +81,7 @@ def get_value_for_feature(self, feature_index): corresponding to the given index. 
:param feature_index: index of the feature in self.feature_values + :param rng: Numpy random Generator :return: feature value """ diff --git a/bhepop2/sources/global_distribution.py b/bhepop2/sources/global_distribution.py index 901b857..0171220 100644 --- a/bhepop2/sources/global_distribution.py +++ b/bhepop2/sources/global_distribution.py @@ -44,7 +44,7 @@ def _validate_data(self): assert set(self.data.columns) >= {f"D{i}" for i in range(1, 10)} assert len(self.data) == 1 - def get_value_for_feature(self, feature_index): + def get_value_for_feature(self, feature_index, rng): """ Return a value drawn from the interval corresponding to the feature index. @@ -52,14 +52,15 @@ def get_value_for_feature(self, feature_index): and so on. The value is drawn using a uniform rule. :param feature_index: + :param rng: :return: """ interval_values = [self._abs_minimum] + self.feature_values - + print(interval_values) lower, upper = interval_values[feature_index], interval_values[feature_index + 1] - draw = self.rng.uniform() + draw = rng.uniform() drawn_feature_value = lower + (upper - lower) * draw diff --git a/bhepop2/sources/marginal_distributions.py b/bhepop2/sources/marginal_distributions.py index f1b3e00..1301bc4 100644 --- a/bhepop2/sources/marginal_distributions.py +++ b/bhepop2/sources/marginal_distributions.py @@ -196,7 +196,7 @@ def compute_feature_prob(self, attribute=ALL_LABEL, modality=ALL_LABEL): return res - def get_value_for_feature(self, feature_index): + def get_value_for_feature(self, feature_index, rng): # directly return the stored feature value return self.feature_values[feature_index] @@ -323,7 +323,7 @@ def compute_feature_prob(self, attribute=ALL_LABEL, modality=ALL_LABEL): return prob_df - def get_value_for_feature(self, feature_index): + def get_value_for_feature(self, feature_index, rng): """ Return a value drawn from the interval corresponding to the feature index. @@ -331,6 +331,7 @@ def get_value_for_feature(self, feature_index): and so on. 
The value is drawn using a uniform rule. :param feature_index: + :param rng: :return: """ @@ -338,7 +339,7 @@ def get_value_for_feature(self, feature_index): lower, upper = interval_values[feature_index], interval_values[feature_index + 1] - draw = self.rng.uniform() + draw = rng.uniform() drawn_feature_value = lower + (upper - lower) * draw