Skip to content

Commit

Permalink
clean up + refactorings
Browse files Browse the repository at this point in the history
  • Loading branch information
seebi committed Apr 9, 2024
1 parent d428fdd commit 09c3bb9
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 66 deletions.
6 changes: 2 additions & 4 deletions .idea/imr.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 24 additions & 18 deletions imr/imr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Inteligent Model Registry (imr)
import os
"""Intelligent Model Registry (imr)"""
import shutil
from pathlib import Path

Expand Down Expand Up @@ -39,12 +38,12 @@ def pull(self, directory: str, package: str, version: str = "latest") -> None:
auth=(self.user, self.password),
auth_type=HTTPBasicAuth,
)
file_path = directory + "/" + package + "/" + version
file_path: Path = Path(directory) / package / version
if not Path(file_path).exists():
Path.mkdir(file_path, parents=True)
with path.open() as fd, Path.open(file_path + "/" + "model.zip", "wb") as out:
with path.open() as fd, Path.open(file_path / "model.zip", "wb") as out:
out.write(fd.read())
shutil.unpack_archive(file_path + "/" + "model.zip", file_path + "/model")
shutil.unpack_archive(file_path / "model.zip", file_path / "model")

def rm(self, package: str, version: str = "latest") -> None:
"""Remove a model from the remote repository."""
Expand All @@ -59,30 +58,37 @@ def rm(self, package: str, version: str = "latest") -> None:
class IMRLocal:
"""Local repository class."""

repo: Path
home: Path

def __init__(self, repo: str | None = None):
self.home = Path.home()
if repo is None:
self.repo = str(self.home) + "/.imr"
self.repo = self.home / ".imr"
else:
self.repo = repo
self.repo = Path(repo)
if not Path(self.repo).exists():
Path.mkdir(self.repo, parents=True)
Path(self.repo).mkdir(parents=True)

def list(self) -> list:
"""List local models."""
return [
entry[0].replace(self.repo + "/", "")
for entry in os.walk(self.repo)
if len(entry[1]) == 0 and entry[0] not in self.repo
]
def list(self) -> list[str]:
"""List local models.
returns all the paths in the repository after main directory
BASE/modela/version1
BASE/modela/version2
BASE/modelb/version3
results in ["modela/version1", "modela/version2", "modelb/version3"]
"""
return ["/".join(version.parts[-2:]) for version in self.repo.rglob("*/*/")]

def push(self, directory: str, package: str, version: str = "latest") -> None:
"""Push a model in a directory to the local repository."""
shutil.copytree(directory, self.repo + "/" + package + "/" + version)
shutil.copytree(src=directory, dst=self.repo / package / version)

def rm(self, package: str, version: str = "latest") -> None:
"""Remove a model from the local repository."""
if version is None:
shutil.rmtree(self.repo + "/" + package)
shutil.rmtree(self.repo / package)
else:
shutil.rmtree(self.repo + "/" + package + "/" + version)
shutil.rmtree(self.repo / package / version)
73 changes: 37 additions & 36 deletions imr/imr_cli.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
from pathlib import Path
""" Path module.
"""Path module.
For handling local files and directory operations.
"""

from pathlib import Path

import click
import yaml

from imr import IMRLocal, IMRRemote
from imr.imr import IMRLocal, IMRRemote

home = Path.home()
imr_dir = home + "/.imr"
imr_local: IMRLocal = IMRLocal(imr_dir)
imr_remote: IMRRemote = None
imr_config = None

def load_params() -> None:
"""Load the default parameters from conf.yaml file."""
with Path.open(imr_dir + "/config.yaml") as stream:
imr_config = yaml.safe_load(stream)
class Context:
"""The context for all CLI commands."""

home = Path.home()
imr_dir = home / ".imr"
imr_local: IMRLocal = IMRLocal(str(imr_dir))
imr_remote: IMRRemote


@click.group()
def cli() -> None:
@click.pass_context
def cli(ctx: click.Context) -> None:
"""Get the cli command options."""
ctx.obj = Context()
# add loadParams() later


Expand All @@ -33,48 +33,52 @@ def local() -> None:


@local.command("list")
def list_local() -> None:
@click.pass_obj
def list_local(obj: Context) -> None:
"""List local packages."""
for package in imr_local.list():
print(package)
for package in obj.imr_local.list():
click.echo(package)


@local.command()
@click.argument("package")
@click.option(
"-v", "--version", type=str, default="latest", help="version of the model.", show_default=True
)
def remove(package: str, version: str) -> None:
@click.pass_obj
def remove(obj: Context, package: str, version: str) -> None:
"""Remove local packages."""
local.rm(package, version)
obj.imr_local.rm(package, version)


@cli.group()
@click.argument("host")
@click.argument("user")
@click.argument("password")
def remote(host: str, user: str, password: str) -> None:
@click.pass_obj
def remote(obj: Context, host: str, user: str, password: str) -> None:
"""Get remote command cli options."""
global imr_remote
imr_remote = IMRRemote(host, user, password)
obj.imr_remote = IMRRemote(host, user, password)


@remote.command("list")
def list_remote() -> None:
@click.pass_obj
def list_remote(obj: Context) -> None:
"""List remote packages."""
packages = imr_remote.list()
packages = obj.imr_remote.list()
for p in packages:
print(p)
click.echo(p)


@remote.command()
@click.argument("package")
@click.option(
"-v", "--version", type=str, default="latest", help="version of the model.", show_default=True
)
def rm(package: str, version: str) -> None:
@click.pass_obj
def rm(obj: Context, package: str, version: str) -> None:
"""Remove remote package."""
imr_remote.rm(package, version)
obj.imr_remote.rm(package, version)


@remote.command()
Expand All @@ -83,9 +87,10 @@ def rm(package: str, version: str) -> None:
@click.option(
"-v", "--version", type=str, default="latest", help="version of the model.", show_default=True
)
def push(model_dir: str, package: str, version: str) -> None:
@click.pass_obj
def push(obj: Context, model_dir: str, package: str, version: str) -> None:
"""Push model to remote repository."""
imr_remote.push(model_dir, package, version)
obj.imr_remote.push(model_dir, package, version)


@remote.command()
Expand All @@ -94,17 +99,13 @@ def push(model_dir: str, package: str, version: str) -> None:
"-d",
"--dir",
type=str,
default=imr_dir,
help="directory to pull the model in.",
show_default=True,
)
@click.option(
"-v", "--version", type=str, default="latest", help="version of the model.", show_default=True
)
def pull(package: str, model_dir: str, version: str) -> None:
@click.pass_obj
def pull(obj: Context, package: str, model_dir: str, version: str) -> None:
"""Pull model from remote repository."""
imr_remote.pull(model_dir, package, version)


if __name__ == "__main__":
cli()
obj.imr_remote.pull(model_dir, package, version)
11 changes: 4 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README-public.md"
homepage = "https://github.com/eccenca/imr"

[tool.poetry.scripts]
imr = 'imr-cli:main'
imr = 'imr.imr_cli:cli'

[tool.poetry.dependencies]
# if you need to change python version here, change it also in .python-version
Expand Down
4 changes: 4 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""tests"""

from pathlib import Path

FIXTURE_DIR = Path(__file__).parent / "fixtures"
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
15 changes: 15 additions & 0 deletions tests/test_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Repository tests."""

from imr.imr import IMRLocal
from tests import FIXTURE_DIR


def test_basic_list() -> None:
"""Test list items in repository."""
number_of_models = 5
directory = FIXTURE_DIR / "repository"
repository = IMRLocal(repo=str(directory))
models = repository.list()
assert len(models) == number_of_models
assert "modela/version1" in models
assert "ignore-me" not in models

0 comments on commit 09c3bb9

Please sign in to comment.