Skip to content

Commit

Permalink
Save full[er] GeneratorRuns (#2515)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2515

Reviewed By: eonofrey

Differential Revision: D58464441
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jun 13, 2024
1 parent 6a30c90 commit 6226cc3
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions ax/core/generator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,7 @@ def __init__(
"one is provided."
)
for arm, weight in zip(arms, weights):
existing_cw = self._arm_weight_table.get(arm.signature)
if existing_cw:
self._arm_weight_table[arm.signature] = ArmWeight(
arm=arm, weight=existing_cw.weight + weight
)
else:
self._arm_weight_table[arm.signature] = ArmWeight(
arm=arm, weight=weight
)
self.add_arm(arm=arm, weight=weight)

self._generator_run_type: Optional[str] = type
self._time_created: datetime = datetime.now()
Expand Down Expand Up @@ -394,6 +386,22 @@ def clone(self) -> GeneratorRun:
)
return generator_run

def add_arm(self, arm: Arm, weight: float = 1.0) -> None:
"""Adds an arm to this generator run. This should not be used to
mutate generator runs that are attached to trials.
Args:
arm: The arm to add.
weight: The weight to associate with the arm.
"""
existing_cw = self._arm_weight_table.get(arm.signature)
if existing_cw:
self._arm_weight_table[arm.signature] = ArmWeight(
arm=arm, weight=existing_cw.weight + weight
)
else:
self._arm_weight_table[arm.signature] = ArmWeight(arm=arm, weight=weight)

def __repr__(self) -> str:
"""String representation of a GeneratorRun."""
class_name = self.__class__.__name__
Expand Down

0 comments on commit 6226cc3

Please sign in to comment.