Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Convert Params Models To Dictionary Before Assigning As Private Attribute In OBBject. #6492

Merged
merged 11 commits into from
Jun 13, 2024
8 changes: 5 additions & 3 deletions cli/openbb_cli/argparse_translator/obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def all(self) -> Dict[int, Dict]:
def _handle_standard_params(obbject: OBBject) -> str:
"""Handle standard params for obbjects"""
standard_params_json = ""
std_params = obbject._standard_params # pylint: disable=protected-access
if hasattr(std_params, "__dict__"):
std_params = getattr(
obbject, "_standard_params", {}
) # pylint: disable=protected-access
if std_params:
standard_params = {
k: str(v)[:30] for k, v in std_params.__dict__.items() if v
k: str(v)[:30] for k, v in std_params.items() if v and k != "data"
}
standard_params_json = json.dumps(standard_params)

Expand Down
2 changes: 1 addition & 1 deletion cli/tests/test_argparse_translator_obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def model_json_schema(self):
obb.extra = {"command": "test_command"}
obb._route = "/test/route"
obb._standard_params = Mock()
obb._standard_params.__dict__ = {}
obb._standard_params = {}
obb.results = [MockModel(1), MockModel(2)]
return obb

Expand Down
37 changes: 24 additions & 13 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,20 @@ def _chart(
raise OpenBBError(
"Charting is not installed. Please install `openbb-charting`."
)
# Here we will pop the chart_params kwargs and flatten them into the kwargs.
chart_params = {}
extra_params = kwargs.get("extra_params", {})
extra_params = getattr(obbject, "_extra_params", {})
deeleeramone marked this conversation as resolved.
Show resolved Hide resolved

if hasattr(extra_params, "__dict__") and hasattr(
extra_params, "chart_params"
):
chart_params = kwargs["extra_params"].__dict__.get("chart_params", {})
elif isinstance(extra_params, dict) and "chart_params" in extra_params:
chart_params = kwargs["extra_params"].get("chart_params", {})
if extra_params and "chart_params" in extra_params:
chart_params = extra_params.get("chart_params", {})

if "chart_params" in kwargs and kwargs["chart_params"] is not None:
if kwargs.get("chart_params"):
chart_params.update(kwargs.pop("chart_params", {}))

# Verify that kwargs is not nested as kwargs so we don't miss any chart params.
if (
"kwargs" in kwargs
and "chart_params" in kwargs["kwargs"]
and kwargs["kwargs"].get("chart_params") is not None
and kwargs["kwargs"].get("chart_params")
):
chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {}))

Expand Down Expand Up @@ -308,9 +305,23 @@ async def _execute_func(

try:
obbject = await cls._command(func, kwargs)
# pylint: disable=protected-access
obbject._route = route
obbject._standard_params = kwargs.get("standard_params", None)

# This section prepares the obbject to pass to the charting service.
obbject._route = route # pylint: disable=protected-access
std_params = kwargs.get("standard_params", {})
if std_params and hasattr(std_params, "__dict__"):
std_params = std_params.__dict__
elif "data" in kwargs:
std_params = kwargs

xtra_params = kwargs.get("extra_params", {})
if xtra_params and hasattr(xtra_params, "__dict__"):
xtra_params = xtra_params.__dict__

obbject._standard_params = ( # pylint: disable=protected-access
std_params
)
obbject._extra_params = xtra_params # pylint: disable=protected-access
deeleeramone marked this conversation as resolved.
Show resolved Hide resolved
if chart and obbject.results:
cls._chart(obbject, **kwargs)
finally:
Expand Down
12 changes: 10 additions & 2 deletions openbb_platform/core/tests/app/test_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams
Expand Down Expand Up @@ -364,8 +365,15 @@ def __init__(self, results):

def test_static_command_runner_chart():
"""Test _chart method when charting is in obbject.accessors."""
mock_obbject = Mock()
mock_obbject.accessors = ["charting"]

mock_obbject = OBBject(
results=[
{"date": "1990", "value": 100},
{"date": "1991", "value": 200},
{"date": "1992", "value": 300},
],
accessors={"charting": Mock()},
)
mock_obbject.charting.show = Mock()

StaticCommandRunner._chart(mock_obbject) # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,23 +325,12 @@ def show(self, render: bool = True, **kwargs):
charting_function = self._get_chart_function(
self._obbject._route # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["provider"] = (
self._obbject.provider
) # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access

if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
kwargs["obbject_item"] = getattr(self._obbject, "results", [])
kwargs["charting_settings"] = getattr(self, "_charting_settings", {})
kwargs["standard_params"] = getattr(self._obbject, "_standard_params", {})
kwargs["extra_params"] = getattr(self._obbject, "extra_params", {})
kwargs["provider"] = getattr(self._obbject, "provider", "")
kwargs["extra"] = self._obbject.extra
fig, content = charting_function(**kwargs)
fig = self._set_chart_style(fig)
content = fig.show(external=True, **kwargs).to_plotly_json()
Expand Down Expand Up @@ -448,24 +437,12 @@ def to_chart(
kwargs["symbol"] = symbol
kwargs["target"] = target
kwargs["index"] = index
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["provider"] = self._obbject.provider # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access
metadata = kwargs["extra"].get("metadata")
kwargs["extra_params"] = (
metadata.arguments.get("extra_params") if metadata else None
)
if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
kwargs["obbject_item"] = getattr(self._obbject, "results", [])
kwargs["charting_settings"] = getattr(self, "_charting_settings", {})
kwargs["standard_params"] = getattr(self._obbject, "_standard_params", {})
kwargs["extra_params"] = getattr(self._obbject, "extra_params", {})
kwargs["provider"] = getattr(self._obbject, "provider", "")
deeleeramone marked this conversation as resolved.
Show resolved Hide resolved
kwargs["extra"] = self._obbject.extra
try:
if has_data:
self.show(data=data_as_df, render=render, **kwargs)
Expand All @@ -488,7 +465,7 @@ def to_chart(

def _set_chart_style(self, figure: Figure):
"""Set the user preference for light or dark mode."""
style = self._charting_settings.chart_style # pylint: disable=protected-access
style = self._charting_settings.chart_style
font_color = "black" if style == "light" else "white"
paper_bgcolor = "white" if style == "light" else "black"
figure = figure.update_layout(
Expand All @@ -498,7 +475,7 @@ def _set_chart_style(self, figure: Figure):
)
return figure

def toggle_chart_style(self): # pylint: disable=protected-access
def toggle_chart_style(self):
"""Toggle the chart style between light and dark mode."""
if not hasattr(self._obbject.chart, "fig"):
raise ValueError(
Expand Down
Loading