Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make command orion db setup ask for right arguments based on storage backend. #586

Merged
merged 13 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/orion/core/cli/db/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import yaml

import orion.core
from orion.core.io.database import Database
from orion.core.utils.terminal import ask_question

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,11 +54,23 @@ def main(*args):
if cancel.strip().lower() == "n":
return

_type = ask_question("Enter the database type: ", "mongodb")
name = ask_question("Enter the database name: ", "test")
host = ask_question("Enter the database host: ", "localhost")

config = {"database": {"type": _type, "name": name, "host": host}}
# Get database type.
_type = ask_question(
"Enter the database",
choice=Database.typenames,
default="mongodb",
ignore_case=True,
).lower()
# Get database arguments.
db_class = Database.types[Database.typenames.index(_type)]
db_args = db_class.get_defaults()
arg_vals = {}
for arg_name, default_value in sorted(db_args.items()):
arg_vals[arg_name] = ask_question(
"Enter the database {}: ".format(arg_name), default_value
)

config = {"database": {"type": _type, **arg_vals}}

print("Default configuration file will be saved at: ")
print(default_file)
Expand Down
18 changes: 18 additions & 0 deletions src/orion/core/io/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(
**kwargs
):
"""Init method, see attributes of :class:`AbstractDB`."""
defaults = self.get_defaults()
host = defaults.get("host", None) if host is None or host == "" else host
name = defaults.get("name", None) if name is None or name == "" else name

self.host = host
self.name = name
self.port = port
Expand Down Expand Up @@ -266,6 +270,20 @@ def remove(self, collection_name, query):
"""
pass

@classmethod
@abstractmethod
def get_defaults(cls):
"""Get database arguments needed to create a database instance.

Returns
-------
dict
A dictionary mapping an argument name to a default value.
If unexpected, default value can be None.

"""
pass


# pylint: disable=too-few-public-methods
class ReadOnlyDB(object):
Expand Down
10 changes: 10 additions & 0 deletions src/orion/core/io/database/ephemeraldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def remove(self, collection_name, query):

return dbcollection.delete_many(query=query)

@classmethod
def get_defaults(cls):
"""Get database arguments needed to create a database instance.

.. seealso:: :meth:`orion.core.io.database.AbstractDB.get_defaults`
for argument documentation.

"""
return {}


class EphemeralCollection(object):
"""Non permanent collection
Expand Down
10 changes: 10 additions & 0 deletions src/orion/core/io/database/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,13 @@ def _sanitize_attrs(self):
self.options["authSource"] = settings["options"].get(
"authsource", self.name
)

@classmethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could rename this defaults instead and use it inside __init__ to make sure the defaults are coherent. Otherwise we are starting to duplicate the defaults and they may diverge. (defaults here vs defaults in orion.config.storage.database)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't understand here, what should be renamed to defaults ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, the name of the class method, get_defaults.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok ! Renamed get_arguments() to get_defaults() (in derived classes, too).

How should I use it inside __init__ ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inside the init, if a value is set to None or is an empty string, replace with value in get_defaults. Also you will need to use the default values as currently set in config here: https://github.com/Epistimio/orion/blob/develop/src/orion/core/__init__.py#L91. This way we harmonize the defaults.
See @abergeron 's PR for the defaults on PickledDB: https://github.com/Epistimio/orion/pull/585/files#diff-f4fa1b61f95de078bc5673f3d914867b67cef706344b651be8ee5cbeb2d17ad7R102

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok ! Done !

def get_defaults(cls):
"""Get database arguments needed to create a database instance.

.. seealso:: :meth:`orion.core.io.database.AbstractDB.get_defaults`
for argument documentation.

"""
return {"name": "orion", "host": "localhost"}
10 changes: 10 additions & 0 deletions src/orion/core/io/database/pickleddb.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ def locked_database(self, write=True):
except Timeout as e:
raise DatabaseTimeout(TIMEOUT_ERROR_MESSAGE.format(self.timeout)) from e

@classmethod
def get_defaults(cls):
"""Get database arguments needed to create a database instance.

.. seealso:: :meth:`orion.core.io.database.AbstractDB.get_defaults`
for argument documentation.

"""
return {"host": DEFAULT_HOST}


local_file_systems = ["ext2", "ext3", "ext4", "ntfs"]

Expand Down
30 changes: 25 additions & 5 deletions src/orion/core/utils/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""


def ask_question(question, default=None):
def ask_question(question, default=None, choice=None, ignore_case=False):
"""Ask a question to the user and receive an answer.

Parameters
Expand All @@ -15,20 +15,40 @@ def ask_question(question, default=None):
The question to be asked.
default: str
The default value to use if the user enters nothing.
choice: list
List of expected values to check user answer
ignore_case: bool
Used only if choice is provided. If True, ignore case when checking
user answer against given choice.

Returns
-------
str
The answer provided by the user.

"""
if choice is not None:
if ignore_case:
choice = [value.lower() for value in choice]
question = question + " (choice: {})".format(", ".join(choice))

if default is not None:
question = question + " (default: {}) ".format(default)

answer = input(question)

if answer.strip() == "":
return default
while True:
answer = input(question)
if answer.strip() == "":
answer = default
break
if choice is None:
break
if answer in choice or (ignore_case and answer.lower() in choice):
break
print(
"Unexpected value: {}. Must be one of: {}\n".format(
answer, ", ".join(choice)
)
)

return answer

Expand Down
83 changes: 79 additions & 4 deletions tests/functional/commands/test_setup_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import orion.core
import orion.core.cli
from orion.core.io.database import Database


class _mock_input:
Expand All @@ -26,7 +27,7 @@ def test_creation_when_not_existing(monkeypatch, tmp_path):
"""Test if a configuration file is created when it does not exist."""
bouthilx marked this conversation as resolved.
Show resolved Hide resolved
config_path = str(tmp_path) + "/tmp_config.yaml"
monkeypatch.setattr(orion.core, "DEF_CONFIG_FILES_PATHS", [config_path])
monkeypatch.setattr(builtins, "input", _mock_input(["type", "name", "host"]))
monkeypatch.setattr(builtins, "input", _mock_input(["mongodb", "host", "name"]))

try:
os.remove(config_path)
Expand All @@ -40,14 +41,16 @@ def test_creation_when_not_existing(monkeypatch, tmp_path):
with open(config_path, "r") as output:
content = yaml.safe_load(output)

assert content == {"database": {"type": "type", "name": "name", "host": "host"}}
assert content == {"database": {"type": "mongodb", "name": "name", "host": "host"}}


def test_creation_when_exists(monkeypatch, tmp_path):
"""Test if the configuration file is overwritten when it exists."""
config_path = str(tmp_path) + "/tmp_config.yaml"
monkeypatch.setattr(orion.core, "DEF_CONFIG_FILES_PATHS", [config_path])
monkeypatch.setattr(builtins, "input", _mock_input(["y", "type", "name", "host"]))
monkeypatch.setattr(
builtins, "input", _mock_input(["y", "mongodb", "host", "name"])
)

dump = {"database": {"type": "allo2", "name": "allo2", "host": "allo2"}}

Expand Down Expand Up @@ -81,6 +84,48 @@ def test_stop_creation_when_exists(monkeypatch, tmp_path):
assert content == dump


def test_invalid_database(monkeypatch, tmp_path, capsys):
"""Test if command prompt loops when invalid database is typed."""
invalid_db_names = [
"invalid database",
"invalid database again",
"2383ejdd",
"another invalid database",
]
config_path = str(tmp_path) + "/tmp_config.yaml"
monkeypatch.setattr(orion.core, "DEF_CONFIG_FILES_PATHS", [config_path])
monkeypatch.setattr(
builtins,
"input",
_mock_input(
[
*invalid_db_names,
"mongodb",
"the host",
"the name",
]
),
)

bouthilx marked this conversation as resolved.
Show resolved Hide resolved
orion.core.cli.main(["db", "setup"])

with open(config_path, "r") as output:
content = yaml.safe_load(output)

assert content == {
"database": {"type": "mongodb", "name": "the name", "host": "the host"}
}

captured_output = capsys.readouterr().out
for invalid_db_name in invalid_db_names:
assert (
"Unexpected value: {}. Must be one of: {}\n".format(
invalid_db_name, ", ".join(Database.typenames)
)
in captured_output
)


def test_defaults(monkeypatch, tmp_path):
"""Test if the default values are used when nothing user enters nothing."""
config_path = str(tmp_path) + "/tmp_config.yaml"
Expand All @@ -93,5 +138,35 @@ def test_defaults(monkeypatch, tmp_path):
content = yaml.safe_load(output)

assert content == {
"database": {"type": "mongodb", "name": "test", "host": "localhost"}
"database": {"type": "mongodb", "name": "orion", "host": "localhost"}
}


def test_ephemeraldb(monkeypatch, tmp_path):
"""Test if config content is written for an ephemeraldb."""
config_path = str(tmp_path) + "/tmp_config.yaml"
monkeypatch.setattr(orion.core, "DEF_CONFIG_FILES_PATHS", [config_path])
monkeypatch.setattr(builtins, "input", _mock_input(["ephemeraldb"]))

orion.core.cli.main(["db", "setup"])

with open(config_path, "r") as output:
content = yaml.safe_load(output)

assert content == {"database": {"type": "ephemeraldb"}}


def test_pickleddb(monkeypatch, tmp_path):
"""Test if config content is written for an pickleddb."""
host = "my_pickles.db"

config_path = str(tmp_path) + "/tmp_config.yaml"
monkeypatch.setattr(orion.core, "DEF_CONFIG_FILES_PATHS", [config_path])
monkeypatch.setattr(builtins, "input", _mock_input(["pickleddb", host]))

orion.core.cli.main(["db", "setup"])

with open(config_path, "r") as output:
content = yaml.safe_load(output)

assert content == {"database": {"type": "pickleddb", "host": host}}