Skip to content

Commit

Permalink
remove rng from EnrichmentSource and pass rng to get_value_for_featur…
Browse files Browse the repository at this point in the history
…e() method
  • Loading branch information
leo-desbureaux-tellae committed Mar 7, 2024
1 parent dfe307f commit 80f5a43
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 13 deletions.
13 changes: 13 additions & 0 deletions bhepop2/enrichment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ def _evaluate_feature_on_population(self):
# implement feature evaluation using a dedicated algorithm
raise NotImplementedError

def _get_value_for_feature(self, feature_id):
"""
Get a feature value for the given feature id.
This method is a helper that class self.source.get_value_for_feature
with feature id and self.rng.
:param feature_id:
:return: feature value
"""
return self.source.get_value_for_feature(feature_id, self.rng)

# validation and read

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion bhepop2/enrichment/bhepop2.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _draw_feature_value(self, probs):
feature_index = self.rng.choice(feature_indexes, p=filtered_probs)

# get feature value from distribution
value = self.source.get_value_for_feature(feature_index)
value = self._get_value_for_feature(feature_index)

return value

Expand Down
2 changes: 1 addition & 1 deletion bhepop2/enrichment/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _evaluate_feature_on_population(self):

def _draw_feature_value(self):
feature_index = self.rng.integers(len(self.source.feature_values))
return self.source.get_value_for_feature(feature_index)
return self._get_value_for_feature(feature_index)

def _validate_and_process_inputs(self):
assert isinstance(self.source, QuantitativeGlobalDistribution)
7 changes: 2 additions & 5 deletions bhepop2/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ def __init__(self, data, name: str = None):

self.name: str = DEFAULT_SOURCE_NAME if name is None else name

# random number generator.
# Set by the SyntheticPopulationEnrichment class is used with one.
self.rng = None

self.data = data

self._feature_values = None
Expand Down Expand Up @@ -77,14 +73,15 @@ def _validate_data(self):
pass

@abstractmethod
def get_value_for_feature(self, feature_index):
def get_value_for_feature(self, feature_index, rng):
"""
Return a feature value for the given feature index.
Generate a singular value from the feature state
corresponding to the given index.
:param feature_index: index of the feature in self.feature_values
:param rng: Numpy random Generator
:return: feature value
"""
Expand Down
7 changes: 4 additions & 3 deletions bhepop2/sources/global_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,23 @@ def _validate_data(self):
assert set(self.data.columns) >= {f"D{i}" for i in range(1, 10)}
assert len(self.data) == 1

def get_value_for_feature(self, feature_index):
def get_value_for_feature(self, feature_index, rng):
"""
Return a value drawn from the interval corresponding to the feature index.
The first interval is defined as [self._abs_minimum, self.feature_values[0]].
and so on. The value is drawn using a uniform rule.
:param feature_index:
:param rng:
:return:
"""

interval_values = [self._abs_minimum] + self.feature_values

print(interval_values)
lower, upper = interval_values[feature_index], interval_values[feature_index + 1]

draw = self.rng.uniform()
draw = rng.uniform()

drawn_feature_value = lower + (upper - lower) * draw

Expand Down
7 changes: 4 additions & 3 deletions bhepop2/sources/marginal_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def compute_feature_prob(self, attribute=ALL_LABEL, modality=ALL_LABEL):

return res

def get_value_for_feature(self, feature_index):
def get_value_for_feature(self, feature_index, rng):
# directly return the stored feature value
return self.feature_values[feature_index]

Expand Down Expand Up @@ -323,22 +323,23 @@ def compute_feature_prob(self, attribute=ALL_LABEL, modality=ALL_LABEL):

return prob_df

def get_value_for_feature(self, feature_index):
def get_value_for_feature(self, feature_index, rng):
"""
Return a value drawn from the interval corresponding to the feature index.
The first interval is defined as [self._abs_minimum, self.feature_values[0]].
and so on. The value is drawn using a uniform rule.
:param feature_index:
:param rng:
:return:
"""

interval_values = [self._abs_minimum] + self.feature_values

lower, upper = interval_values[feature_index], interval_values[feature_index + 1]

draw = self.rng.uniform()
draw = rng.uniform()

drawn_feature_value = lower + (upper - lower) * draw

Expand Down

0 comments on commit 80f5a43

Please sign in to comment.