Skip to content

Commit

Permalink
Add method to remove state when serialising model
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian Gonzalez-Martin committed Apr 6, 2021
1 parent 19a1efd commit 78d3f15
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
12 changes: 12 additions & 0 deletions tempo/serve/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,18 @@ def loadmethod(self, load_func: LoadMethodSignature) -> LoadMethodSignature:

return load_func

def __getstate__(self) -> dict:
"""
__getstate__ gets called by pickle before serialising an object to get
its internal representation.
We override __getstate__ to make sure that the model's internal context
is not pickled with the object.
"""
state = self.__dict__.copy()
state["context"] = SimpleNamespace()

return state

@classmethod
def load(cls, folder: str) -> "BaseModel":
file_path_pkl = os.path.join(folder, DefaultModelFilename)
Expand Down
16 changes: 14 additions & 2 deletions tests/serve/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def custom_model(a: np.ndarray) -> np.ndarray:
#
# Test lambda function
#
@pytest.mark.parametrize("input, expected", [(np.array([[0, 0, 0, 1]]), np.array([[0, 0, 1]]))])
@pytest.mark.parametrize(
"input, expected", [(np.array([[0, 0, 0, 1]]), np.array([[0, 0, 1]]))]
)
def test_lambda(input, expected):
model = Model(
name="test-iris-sklearn",
Expand Down Expand Up @@ -158,7 +160,9 @@ def test_custom_multiheaded_model_tuple(v2_input, expected):
name="multi-headed",
platform=ModelFramework.Custom,
)
def custom_multiheaded_model_tuple(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
def custom_multiheaded_model_tuple(
a: np.ndarray, b: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
return a, b

response = custom_multiheaded_model_tuple.request(v2_input)
Expand Down Expand Up @@ -236,3 +240,11 @@ def predict(self, X: str) -> str:
def test_custom_loadmethod(custom_model):
pred = custom_model(payload=np.array([1, 2, 3]))
assert pred == np.array([6])


def test_model_save(custom_model: Model):
custom_model.save(save_env=False)
loaded = Model.load(custom_model.details.local_folder)

assert len(custom_model.context.__dict__) > 0
assert len(loaded.context.__dict__) == 0

0 comments on commit 78d3f15

Please sign in to comment.