Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update docstring for GenerationNode.gen & fit #2245

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,15 @@ def fit(
the model kwargs set on each corresponding model spec and the kwargs
passed to this method.

NOTE: Local kwargs take precedence over the ones stored in
``ModelSpec.model_kwargs``.
Args:
experiment: The experiment to fit the model to.
data: The experiment data used to fit the model.
search_space: An optional overwrite for the experiment search space.
optimization_config: An optional overwrite for the experiment
optimization config.
kwargs: Additional keyword arguments to pass to the model's
``fit`` method. NOTE: Local kwargs take precedence over the ones
stored in ``ModelSpec.model_kwargs``.
"""
self._model_spec_to_gen_from = None
for model_spec in self.model_specs:
Expand All @@ -283,21 +290,27 @@ def gen(
alongside any kwargs passed in to this function (with local kwargs)
taking precedent.

NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
model-determined number of arms. In that case this method will also output
a generator run with number of arms that may differ from ``n``.

Args:
n: Optional nteger representing how many arms should be in the generator
n: Optional integer representing how many arms should be in the generator
run produced by this method. When this is ``None``, ``n`` will be
determined by the ``ModelSpec`` that we are generating from.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
max_gen_draws_for_deduplication: TODO
max_gen_draws_for_deduplication: Maximum number of attempts for generating
new candidates without duplicates. If non-duplicate candidates are not
generated with these attempts, a ``GenerationStrategyRepeatedPoints``
exception will be raised.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
these override any pre-specified in ``ModelSpec.model_gen_kwargs``.

NOTE: Models must have been fit prior to calling ``gen``.
NOTE: Some underlying models may ignore the ``n`` argument and produce a
model-determined number of arms. In that case this method will also output
a generator run with number of arms (that can differ from ``n``).
Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
model_spec = self.model_spec_to_gen_from
should_generate_run = True
Expand Down
Loading