Skip to content

Commit

Permalink
Merge pull request #286 from STALLAAAA/dev
Browse files Browse the repository at this point in the history
[MRG] add CLI of convert
  • Loading branch information
HuaizhengZhang committed Apr 21, 2021
2 parents bf1123e + e1b5c98 commit 72fe5ac
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 5 deletions.
14 changes: 14 additions & 0 deletions example/alexnet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
weight: "~/.modelci/Alexnext/pytorch-pytorch/image_classification/1.pth"
dataset: ImageNet
task: IMAGE_CLASSIFICATION
metric:
acc: 0.74
inputs:
- name: "input"
shape: [ 1, 3, 224, 224 ]
dtype: TYPE_FP32
outputs:
- name: "output"
shape: [ -1, 1000 ]
dtype: TYPE_FP32
convert: false
52 changes: 52 additions & 0 deletions modelci/cli/modelhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,21 @@
# permissions and limitations under the License.
from http import HTTPStatus
from pathlib import Path
from shutil import copy2, make_archive
from typing import Dict, List, Optional

import requests
import typer
import yaml
from pydantic import ValidationError
import modelci.persistence.service_ as ModelDB
from modelci.hub.manager import generate_model_family

from modelci.config import app_settings
from modelci.hub.utils import parse_path_plain
from modelci.types.models import Framework, Engine, IOShape, Task, Metric, ModelUpdateSchema
from modelci.types.models import MLModelFromYaml, MLModel
from modelci.types.models.common import ModelStatus
from modelci.ui import model_view, model_detailed_view
from modelci.utils import Logger
from modelci.utils.misc import remove_dict_null
Expand Down Expand Up @@ -242,3 +247,50 @@ def delete(model_id: str = typer.Argument(..., help='Model ID')):
with requests.delete(f'{app_settings.api_v1_prefix}/model/{model_id}') as r:
if r.status_code == HTTPStatus.NO_CONTENT:
typer.echo(f"Model {model_id} deleted")


@app.command('convert')
def convert(
id: str = typer.Option(None, '-i', '--id', help='ID of model.'),
yaml_file: Optional[Path] = typer.Option(
None, '-f', '--yaml-file', exists=True, file_okay=True,
help='Path to configuration YAML file. You should either set the `yaml_file` field or fields '
'(`FILE_OR_DIR`, `--name`, `--framework`, `--engine`, `--version`, `--task`, `--dataset`,'
'`--metric`, `--input`, `--output`).'
),
register: bool = typer.Option(False, '-r', '--register', is_flag=True, help='register the converted models to modelhub, default false')
):
if id is not None and yaml_file is not None:
typer.echo("Do not use -id and -path at the same time.")
typer.Exit()
elif id is not None and yaml_file is None:
if ModelDB.exists_by_id(id):
model = ModelDB.get_by_id(id)
else:
typer.echo(f"model id: {id} does not exist in modelhub")
elif id is None and yaml_file is not None:
# get MLModel from yaml file
with open(yaml_file) as f:
model_config = yaml.safe_load(f)
model_yaml = MLModelFromYaml.parse_obj(model_config)
model_in_saved_path = model_yaml.saved_path
if model_in_saved_path != model_yaml.weight:
copy2(model_yaml.weight, model_in_saved_path)
if model_yaml.engine == Engine.TFS:
weight_dir = model_yaml.weight
make_archive(weight_dir.with_suffix('.zip'), 'zip', weight_dir)

model_data = model_yaml.dict(exclude_none=True, exclude={'convert', 'profile'})
model = MLModel.parse_obj(model_data)

# auto execute all possible convert and return a list of save paths of every converted model
generated_dir_list = generate_model_family(model)
typer.echo(f"Converted models are save in: {generated_dir_list}")
if register:
model_data = model.dict(exclude={'weight', 'id', 'model_status', 'engine'})
for model_dir in generated_dir_list:
parse_result = parse_path_plain(model_dir)
engine = parse_result['engine']
model_cvt = MLModel(**model_data, weight=model_dir, engine=engine, model_status=[ModelStatus.CONVERTED])
ModelDB.save(model_cvt)
typer.echo(f"converted {engine} are successfully registered in Modelhub")
16 changes: 11 additions & 5 deletions modelci/hub/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from modelci.types.bo import Task, ModelVersion, Framework, ModelBO

__all__ = ['get_remote_model_weight', 'register_model', 'register_model_from_yaml', 'retrieve_model',
'retrieve_model_by_task', 'retrieve_model_by_parent_id']
'retrieve_model_by_task', 'retrieve_model_by_parent_id', 'generate_model_family']

from modelci.types.models.common import Engine, ModelStatus

Expand Down Expand Up @@ -72,7 +72,7 @@ def register_model(

# generate model family
if convert:
model_dir_list.extend(_generate_model_family(model))
model_dir_list.extend(generate_model_family(model))

# register
model_data = model.dict(exclude={'weight', 'id', 'model_status', 'engine'})
Expand Down Expand Up @@ -149,11 +149,17 @@ def register_model_from_yaml(file_path: Union[Path, str]):
register_model(model, convert=model_yaml.convert, profile=model_yaml.profile)


def _generate_model_family(
def generate_model_family(
model: MLModel,
max_batch_size: int = -1
):
net = load(model.saved_path)
model_weight_path = model.saved_path
if not Path(model.saved_path).exists():
(filepath, filename) = os.path.split(model.saved_path)
os.makedirs(filepath)
with open(model.saved_path, 'wb') as f:
f.write(model.weight.__bytes__())
net = load(model_weight_path)
build_saved_dir_from_engine = partial(
generate_path_plain,
**model.dict(include={'architecture', 'framework', 'task', 'version'}),
Expand Down Expand Up @@ -323,4 +329,4 @@ def retrieve_model_by_parent_id(parent_id: str) -> List[ModelBO]:
if len(models) == 0:
raise FileNotFoundError('Model not found!')

return models
return models

0 comments on commit 72fe5ac

Please sign in to comment.