Skip to content

Commit

Permalink
Merge pull request #288 from univerone/add-model-retrieving
Browse files Browse the repository at this point in the history
[MRG]Fix model retrieving from modelhub
  • Loading branch information
HuaizhengZhang committed Apr 15, 2021
2 parents d20f6d7 + a14e290 commit 8ba8605
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions modelci/types/models/mlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,61 @@
"""
import getpass
import os
import gridfs
from datetime import datetime
from pathlib import Path
from typing import Union, Optional, Dict, List, Any

from bson import ObjectId
from gridfs import GridOut
from pydantic import BaseModel, FilePath, DirectoryPath, PositiveInt, Field, root_validator

from .common import Metric, IOShape, Framework, Engine, Task, ModelStatus, Status, PydanticObjectId, \
named_enum_json_encoder
from .pattern import as_form
from ...hub.utils import parse_path_plain, generate_path_plain
from modelci.config import db_settings
from modelci.experimental.mongo_client import MongoClient

_db = MongoClient()[db_settings.mongo_db]
_fs = gridfs.GridFS(_db)


class Weight(BaseModel):
"""TODO: Only works for MLModelIn"""

__slots__ = ('file',)
__slots__ = ('file', '_gridfs_out')

__root__: Optional[PydanticObjectId]

def __init__(self, __root__):
if isinstance(__root__, Path):
object.__setattr__(self, 'file', FilePath.validate(__root__))
object.__setattr__(self, '_gridfs_out', None)
__root__ = None

self._grid_out: Optional[GridOut]
if isinstance(__root__, ObjectId):
if _fs.exists(__root__):
object.__setattr__(self, 'file', None)
object.__setattr__(self, '_gridfs_out', _fs.get(__root__))

super().__init__(__root__=__root__)

@property
def filename(self):
if self.file:
return self.file.name
return ''
elif self._gridfs_out:
return self._gridfs_out.filename
else:
return ''

def __bytes__(self):
if self.file:
return self.file.read_bytes()
elif self._gridfs_out:
return self._gridfs_out.read()
else:
return b''


@as_form
Expand Down Expand Up @@ -85,7 +101,7 @@ def dict(self, use_enum_values: bool = False, **kwargs):
# fix metric key as a Enum
metric: dict = data.get('metric', None)
if metric:
data['metric'] = {Metric(k).name: v for k, v in metric.items() }
data['metric'] = {Metric(k).name: v for k, v in metric.items()}

return data

Expand Down Expand Up @@ -182,4 +198,4 @@ class ModelUpdateSchema(BaseModel):
outputs: Optional[List[IOShape]] = Field(
default_factory=list,
example='[{"name": "output", "shape": [-1, 1000], "dtype": "TYPE_FP32"}]'
)
)

0 comments on commit 8ba8605

Please sign in to comment.