Skip to content

Commit bf1f470

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Separating out get_map_values helper from MapUnitX transform (facebook#3313)
Summary: This commit separates out a `get_map_values` helper function from the `MapUnitX` transform. Differential Revision: D69213291
1 parent 8d08aa3 commit bf1f470

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

Diff for: ax/modelbridge/transforms/map_unit_x.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,7 @@ def __init__(
4141
) -> None:
4242
assert observations is not None, "MapUnitX requires observations"
4343
assert search_space is not None, "MapUnitX requires search space"
44-
# Loop through observation features and identify parameters that
45-
# are not part of the search space. Store all observed values to
46-
# infer bounds
47-
map_values = defaultdict(list)
48-
for obs in observations:
49-
for p in obs.features.parameters:
50-
if p not in search_space.parameters:
51-
map_values[p].append(obs.features.parameters[p])
44+
map_values = get_map_values(search_space, observations)
5245

5346
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
5447
# `typing.List` to avoid runtime subscripting errors.
@@ -81,3 +74,30 @@ def untransform_observation_features(
8174
scale_fac = (u - l) / self.target_range
8275
obsf.parameters[p_name] = scale_fac * (param - self.target_lb) + l
8376
return observation_features
77+
78+
79+
def get_map_values(
80+
search_space: SearchSpace,
81+
observations: list[Observation],
82+
) -> dict[str, list[float]]:
83+
"""Computes a dictionary mapping the name of a map parameter to its associated
84+
progression values, in the same order as they occur in the observations.
85+
86+
Args:
87+
search_space: The search space.
88+
observations: A list of observations associated with the search space.
89+
90+
Returns:
91+
The dictionary mapping the name of a map metric to the associated values,
92+
in the same order they occur in `observations`.
93+
"""
94+
# Loop through observation features and identify parameters that
95+
# are not part of the search space. Store all observed values to
96+
# infer bounds
97+
map_values = defaultdict(list)
98+
for obs in observations:
99+
# if we had access to the original data object, could loop over data.map_keys
100+
for p in obs.features.parameters:
101+
if p not in search_space.parameters:
102+
map_values[p].append(obs.features.parameters[p])
103+
return map_values

0 commit comments

Comments
 (0)