This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

changes the way we inherit modules
Because we want to allow people to transfer their NumpyModule
to a TorchModule so they can get access to TorchScript (see issue #42),
we need to change the way we do inheritance. Before it was

Module (reference to torch.nn.Module or our BaseModule)
  --> CustomModule (e.g. Ridge)

But that means that when the library is loaded there is just one class
CustomModule, which inherits from either torch.nn.Module or BaseModule
depending on whether torch is available on the machine. With a single
inheritance chain it is hard to switch between the classes, and changing
the base class afterwards is very hacky, so this is not a good approach.
Therefore we create both classes when torch is present
(note: BaseModule was renamed to NumpyModule):

def factory_custom_module(base):
    class _CustomModule(base):
        ...
    # change name ...
    return _CustomModule

CustomNumpyModule = factory_custom_module(NumpyModule)
CustomTorchModule = factory_custom_module(torch.nn.Module)

if HAS_TORCH:
    CustomModule = CustomTorchModule
else:
    CustomModule = CustomNumpyModule
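
Downstream code can then pick the variant it needs explicitly; for example
(sketch, using the names from above):

    model = CustomNumpyModule()      # always available
    if HAS_TORCH:
        model = CustomTorchModule()  # eligible for TorchScript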
agoscinski committed Mar 18, 2023
1 parent 727f9b0 commit 8385523
Showing 1 changed file: src/equisolve/module.py (176 additions, 26 deletions)
@@ -1,51 +1,132 @@
import pickle

from abc import abstractmethod, ABCMeta

from typing import TypeVar, Dict
from collections import OrderedDict

from equistore import TensorMap, TensorBlock

# Workaround for typing Self with inheritance for python <3.11
# see https://peps.python.org/pep-0673/
TModule = TypeVar("TModule", bound="Module")

class NumpyModule(metaclass=ABCMeta):
    @abstractmethod
    def forward(self, *args, **kwargs):
        return

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    # def __setattr__(self, name, value):
    #     # TODO: use this to keep track of the parameters that
    #     # have been set up, see https://stackoverflow.com/a/54994818
    #     return

    def state_dict(self) -> OrderedDict:
        """
        All required parameters to initialize a fitted module.
        """
        # PR COMMENT: not implemented, but should be analogous
        # to torch.nn.Module by using __setattr__
        return

    # PR COMMENT: Torch does not return a boolean but a
    # torch.nn.modules.module._IncompatibleKeys;
    # need to think about whether we keep the bool
    # or make something analogous
    def load_state_dict(self, state_dict: OrderedDict) -> bool:
        """
        Initialize a fitted module.
        """
        # PR COMMENT: not implemented, but should be analogous
        # to torch.nn.Module by using __setattr__
        return
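
# For example, once state_dict/load_state_dict are implemented, a fitted
# module should round-trip like this (sketch; "weights" is a hypothetical
# parameter name):
#
#     state = model.state_dict()    # e.g. OrderedDict([("weights", TensorMap)])
#     fresh = model.__class__()
#     fresh.load_state_dict(state)  # fresh now predicts like model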

try:
    import torch

    HAS_TORCH = True
    Module = torch.nn.Module
except ImportError:
    HAS_TORCH = False
    Module = NumpyModule

# PR COMMENT:
# We want to allow people to switch to a torch model if needed
# (to use TorchScript for low-level integration into MD code).
# For that we need both class definitions (numpy and torch),
# so we can convert a NumpyModule to a torch Module.
# The use case I see: you did not work with torch, so you have
# a class based on NumpyModule, but now you need a class based on
# torch.nn.Module. So we provide a utility function that takes all
# the parameters (state_dict) of the numpy module, converts them to
# torch tensors, and reinitializes the corresponding torch module
# with them.
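#
# For example (sketch; the Ridge-like estimator names are hypothetical):
#
#     model = RidgeNumpy()      # torch-free training workflow
#     model.fit(X, y)
#     torch_model = convert_to_torch_module(model)  # ready for TorchScript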

# all modules that have a corresponding torch module are stored here
NUMPY_TO_TORCH_MODULE = {}

# PR COMMENT: I think this cannot be hidden in a class decorator like
#
#     # automatically creates EstimatorModuleNumpy and EstimatorModuleTorch
#     @equisolve.module
#     class EstimatorModule(Module):
#         ...
#
# because we would need to change the base class,
# and that is super hacky, see https://stackoverflow.com/a/9541560

# PR COMMENT: all modules that support both backends need to be built like this
def estimator_module_factory(base_class, name):
    class _EstimatorModule(base_class, metaclass=ABCMeta):
        def __init__(self):
            super().__init__()

        def forward(self, X: TensorMap):
            return self.predict(X)

        @abstractmethod
        def fit(self, X: TensorMap, y: TensorMap) -> base_class:
            return

        @abstractmethod
        def predict(self, X: TensorMap) -> TensorMap:
            return

        @abstractmethod
        def score(self, X: TensorMap, y: TensorMap) -> TensorMap:
            return

        def fit_score(self, X: TensorMap, y: TensorMap = None) -> TensorMap:
            self.fit(X, y)
            return self.score(X, y)

    # PR COMMENT: this is kind of the idea, but it can be solved more nicely;
    # currently the class is stored as
    # equisolve.module.estimator_module_factory.<locals>._EstimatorModule,
    # which we fix by renaming, see https://stackoverflow.com/q/681953
    _EstimatorModule.__name__ = name
    _EstimatorModule.__qualname__ = name
    return _EstimatorModule

EstimatorNumpyModule = estimator_module_factory(NumpyModule, "EstimatorModuleNumpy")
if HAS_TORCH:
    EstimatorTorchModule = estimator_module_factory(torch.nn.Module, "EstimatorModuleTorch")
    # the key has to match the __name__ given to the numpy variant above,
    # since convert_to_torch_module looks modules up by __class__.__name__
    NUMPY_TO_TORCH_MODULE["EstimatorModuleNumpy"] = EstimatorTorchModule
    # this is just a reference to the default module type
    EstimatorModule = EstimatorTorchModule
else:
    # this is just a reference to the default module type
    EstimatorModule = EstimatorNumpyModule
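
# For example, a concrete estimator that should exist for both backends would
# be built through the same factory pattern (hypothetical Ridge sketch,
# method bodies omitted):
#
#     def ridge_factory(base_class, name):
#         class _Ridge(base_class):
#             def fit(self, X: TensorMap, y: TensorMap):
#                 ...  # solve the regression problem, store the weights
#
#             def predict(self, X: TensorMap) -> TensorMap:
#                 ...
#
#             def score(self, X: TensorMap, y: TensorMap) -> TensorMap:
#                 ...
#
#         _Ridge.__name__ = name
#         _Ridge.__qualname__ = name
#         return _Ridge
#
#     RidgeNumpy = ridge_factory(EstimatorNumpyModule, "RidgeNumpy")
#     if HAS_TORCH:
#         RidgeTorch = ridge_factory(EstimatorTorchModule, "RidgeTorch")
#         NUMPY_TO_TORCH_MODULE["RidgeNumpy"] = RidgeTorch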


# PR COMMENT: for now I did not adapt the TransformerModule;
# it is equivalent, but not worth the effort if we don't follow this approach
TTransformerModule = TypeVar("TTransformerModule", bound="TransformerModule")

class TransformerModule(Module, metaclass=ABCMeta):
    def forward(self, X: TensorMap):
@@ -62,3 +143,72 @@ def transform(self, X: TensorMap) -> TensorMap:
    def fit_transform(self, X: TensorMap, y: TensorMap = None) -> TensorMap:
        self.fit(X, y)
        return self.transform(X)

def save(module: Module, f: str):
    """
    Saves the module to a picklable file.
    """
    # PR COMMENT:
    # torch.save is just a nice wrapper around pickle as far as I understand.
    # It can also save TorchScript by dispatching to jit.save,
    # but that is not really intended, so we can just use it here.
    if HAS_TORCH:
        torch.save(module, f)
    else:
        with open(f, "wb") as file:
            pickle.dump(module, file)


def load(f: str):
    """
    Loads a pickled module.
    """
    if HAS_TORCH:
        return torch.load(f)
    else:
        with open(f, "rb") as file:
            return pickle.load(file)
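
# Round-trip example (hypothetical file name):
#
#     save(model, "model.pickle")
#     model = load("model.pickle")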

if HAS_TORCH:
    # PR COMMENT: Maybe we can also store the init args and kwargs
    # in our models so we can get them from the module input.
    # torch does not do this as far as I have seen, so
    # I did not implement it so far.
    def convert_to_torch_module(module: NumpyModule, *module_init_args, **module_init_kwargs):
        """
        Converts a numpy module to a torch module.
        """
        if module.__class__.__name__ not in NUMPY_TO_TORCH_MODULE.keys():
            raise NotImplementedError(
                f"Your module {module.__class__.__name__} "
                "has not been implemented as a torch module."
            )
        torch_module_class = NUMPY_TO_TORCH_MODULE[module.__class__.__name__]
        torch_module = torch_module_class(*module_init_args, **module_init_kwargs)
        state_dict_numpy = module.state_dict()
        state_dict_torch = {}
        # PR COMMENT: that is the idea, but this code is not working yet;
        # it might also be a bit more complicated if nested objects exist.
        # Need to read in more detail how state_dict works:
        # https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html
        for key, arg in state_dict_numpy.items():
            if isinstance(arg, TensorMap):
                state_dict_torch[key] = arg.to_tensor()
            elif isinstance(arg, TensorBlock):
                state_dict_torch[key] = arg.to_tensor()
        torch_module.load_state_dict(state_dict_torch)
        return torch_module
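
    # For example, plain numpy arrays in a state dict would presumably need
    # the same treatment (sketch; assumes numpy is imported as np):
    #
    #     if isinstance(arg, np.ndarray):
    #         state_dict_torch[key] = torch.from_numpy(arg)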

def save_torch_script(module: Module, f: str):
    """
    Scripts the module and saves it as TorchScript.
    """
    if not HAS_TORCH:
        raise ImportError("Saving a model as TorchScript requires torch.")
    if isinstance(module, torch.nn.Module):
        # torch.jit.save needs a ScriptModule, so script the module first
        # (tracing would be the alternative)
        torch.jit.save(torch.jit.script(module), f)
    elif isinstance(module, NumpyModule):
        raise ValueError(
            f"Your module of type {module.__class__} is not a torch module "
            "and cannot be saved as TorchScript. "
            "Please convert your module first using convert_to_torch_module."
        )
    else:
        raise ValueError(
            f"Your module of type {module.__class__} is not a torch module "
            "and cannot be saved as TorchScript."
        )

def load_torch_script(f: str):
    return torch.jit.load(f)
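
# Example TorchScript round trip (hypothetical file name):
#
#     save_torch_script(torch_model, "model.pt")
#     scripted = load_torch_script("model.pt")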
