Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Convert Params Models To Dictionary Before Assigning As Private Attribute In OBBject. #6492

Merged
merged 11 commits into from
Jun 13, 2024
8 changes: 5 additions & 3 deletions cli/openbb_cli/argparse_translator/obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def all(self) -> Dict[int, Dict]:
def _handle_standard_params(obbject: OBBject) -> str:
"""Handle standard params for obbjects"""
standard_params_json = ""
std_params = obbject._standard_params # pylint: disable=protected-access
if hasattr(std_params, "__dict__"):
std_params = getattr(
obbject, "_standard_params", {}
) # pylint: disable=protected-access
if std_params:
standard_params = {
k: str(v)[:30] for k, v in std_params.__dict__.items() if v
k: str(v)[:30] for k, v in std_params.items() if v and k != "data"
}
standard_params_json = json.dumps(standard_params)

Expand Down
2 changes: 1 addition & 1 deletion cli/tests/test_argparse_translator_obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def model_json_schema(self):
obb.extra = {"command": "test_command"}
obb._route = "/test/route"
obb._standard_params = Mock()
obb._standard_params.__dict__ = {}
obb._standard_params = {}
obb.results = [MockModel(1), MockModel(2)]
return obb

Expand Down
40 changes: 26 additions & 14 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,20 @@ def _chart(
raise OpenBBError(
"Charting is not installed. Please install `openbb-charting`."
)
# Here we will pop the chart_params kwargs and flatten them into the kwargs.
chart_params = {}
extra_params = kwargs.get("extra_params", {})
extra_params = getattr(obbject, "_extra_params", {})
deeleeramone marked this conversation as resolved.
Show resolved Hide resolved

if hasattr(extra_params, "__dict__") and hasattr(
extra_params, "chart_params"
):
chart_params = kwargs["extra_params"].__dict__.get("chart_params", {})
elif isinstance(extra_params, dict) and "chart_params" in extra_params:
chart_params = kwargs["extra_params"].get("chart_params", {})
if extra_params and "chart_params" in extra_params:
chart_params = extra_params.get("chart_params", {})

if "chart_params" in kwargs and kwargs["chart_params"] is not None:
if kwargs.get("chart_params"):
chart_params.update(kwargs.pop("chart_params", {}))

# Verify that kwargs is not nested as kwargs so we don't miss any chart params.
if (
"kwargs" in kwargs
and "chart_params" in kwargs["kwargs"]
and kwargs["kwargs"].get("chart_params") is not None
and kwargs["kwargs"].get("chart_params")
):
chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {}))

Expand Down Expand Up @@ -307,10 +304,25 @@ async def _execute_func(
} or None

try:
obbject = await cls._command(func, kwargs)
# pylint: disable=protected-access
obbject._route = route
obbject._standard_params = kwargs.get("standard_params", None)
obbject = await cls._command(
func, kwargs
) # pylint: disable=protected-access
# This section prepares the obbject to pass to the charting service.
obbject._route = route # pylint: disable=protected-access
std_params = kwargs.get("standard_params", {})
if std_params and hasattr(std_params, "__dict__"):
std_params = std_params.__dict__
elif "data" in kwargs:
std_params = kwargs

xtra_params = kwargs.get("extra_params", {})
if xtra_params and hasattr(xtra_params, "__dict__"):
xtra_params = xtra_params.__dict__

obbject._standard_params = (
std_params # pylint: disable=protected-access
)
obbject._extra_params = xtra_params # pylint: disable=protected-access
hjoaquim marked this conversation as resolved.
Show resolved Hide resolved
if chart and obbject.results:
cls._chart(obbject, **kwargs)
finally:
Expand Down
12 changes: 10 additions & 2 deletions openbb_platform/core/tests/app/test_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams
Expand Down Expand Up @@ -364,8 +365,15 @@ def __init__(self, results):

def test_static_command_runner_chart():
"""Test _chart method when charting is in obbject.accessors."""
mock_obbject = Mock()
mock_obbject.accessors = ["charting"]

mock_obbject = OBBject(
results=[
{"date": "1990", "value": 100},
{"date": "1991", "value": 200},
{"date": "1992", "value": 300},
],
accessors={"charting": Mock()},
)
mock_obbject.charting.show = Mock()

StaticCommandRunner._chart(mock_obbject) # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,23 +325,22 @@ def show(self, render: bool = True, **kwargs):
charting_function = self._get_chart_function(
self._obbject._route # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["provider"] = (
self._obbject.provider
kwargs["obbject_item"] = getattr(
self._obbject, "results", []
) # pylint: disable=protected-access
kwargs["charting_settings"] = getattr(
self, "_charting_settings", {}
) # pylint: disable=protected-access
kwargs["standard_params"] = getattr(
self._obbject, "_standard_params", {}
) # pylint: disable=protected-access
kwargs["extra_params"] = getattr(
self._obbject, "extra_params", {}
) # pylint: disable=protected-access
kwargs["provider"] = getattr(
self._obbject, "provider", ""
) # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access

if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
fig, content = charting_function(**kwargs)
fig = self._set_chart_style(fig)
content = fig.show(external=True, **kwargs).to_plotly_json()
Expand Down Expand Up @@ -448,24 +447,22 @@ def to_chart(
kwargs["symbol"] = symbol
kwargs["target"] = target
kwargs["index"] = index
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["provider"] = self._obbject.provider # pylint: disable=protected-access
kwargs["obbject_item"] = getattr(
self._obbject, "results", []
) # pylint: disable=protected-access
kwargs["charting_settings"] = getattr(
self, "_charting_settings", {}
) # pylint: disable=protected-access
kwargs["standard_params"] = getattr(
self._obbject, "_standard_params", {}
) # pylint: disable=protected-access
kwargs["extra_params"] = getattr(
self._obbject, "extra_params", {}
) # pylint: disable=protected-access
kwargs["provider"] = getattr(
self._obbject, "provider", ""
) # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access
metadata = kwargs["extra"].get("metadata")
kwargs["extra_params"] = (
metadata.arguments.get("extra_params") if metadata else None
)
if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
try:
if has_data:
self.show(data=data_as_df, render=render, **kwargs)
Expand Down
Loading