Skip to content

Commit

Permalink
Relax type hints on compute() (#2662)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2662

Scheduler does not have a GenerationStrategy, it has a GenerationStrategyInterface. Relaxing this type hint is important for creating Scheduler.compute_analyses(...)

Reviewed By: Cesar-Cardoso

Differential Revision: D61339432
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Aug 20, 2024
1 parent 528e69a commit 52959da
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pandas as pd
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.result import Err, ExceptionE, Ok, Result
Expand Down Expand Up @@ -87,7 +87,7 @@ class Analysis(Protocol):
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
) -> AnalysisCard:
# Note: when implementing compute always prefer experiment.lookup_data() to
# experiment.fetch_data() to avoid unintential data fetching within the report
Expand All @@ -97,7 +97,7 @@ def compute(
def compute_result(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
) -> Result[AnalysisCard, ExceptionE]:
"""
Utility method to compute an AnalysisCard as a Result. This can be useful for
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.core.generation_strategy_interface import GenerationStrategyInterface


class MarkdownAnalysisCard(AnalysisCard):
Expand All @@ -27,5 +27,5 @@ class MarkdownAnalysis(Analysis):
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
) -> MarkdownAnalysisCard: ...
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go, io as pio


Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, metric_name: Optional[str] = None) -> None:
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ParallelCoordinatesPlot requires an Experiment")
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/plotly_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from plotly import graph_objects as go, io as pio


Expand All @@ -28,5 +28,5 @@ class PlotlyAnalysis(Analysis):
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
) -> PlotlyAnalysisCard: ...

0 comments on commit 52959da

Please sign in to comment.