Description
In our context, an ensemble, which is constructed from a set of emulators, can be viewed as an emulator itself, whose uncertainty is aggregated from the emulators it holds. Both the `AutoEmulate` object and an ensemble operate over a set of emulators. They share some functionality (e.g. emulator fitting) and differ in other respects (e.g. emulator comparison). Overall, I think `AutoEmulate` can be cast as a subclass of `Ensemble`, but I'm happy to discuss this point. See the elaboration below.
Ensemble
`Emulator.predict(x)` returns `(mean, covariance)`. Repeated calls with the same `x` may return different values, and the covariance may be zero. Let $M$ be the number of emulators in the ensemble and $N$ the number of samples (calls) per emulator. Then the `(mean, covariance)` of the ensemble would be
where the first term of
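One plausible form of this aggregation treats the ensemble as an equally weighted mixture of all $MN$ calls. Here $\mu_{ij}$ and $\Sigma_{ij}$ denote the mean and covariance returned by the $j$-th call to the $i$-th emulator. Both the equal weighting and this notation are assumptions introduced for this sketch:

$$
\mu = \frac{1}{MN}\sum_{i=1}^{M}\sum_{j=1}^{N}\mu_{ij},
\qquad
\Sigma = \frac{1}{MN}\sum_{i=1}^{M}\sum_{j=1}^{N}\Sigma_{ij}
+ \frac{1}{MN}\sum_{i=1}^{M}\sum_{j=1}^{N}(\mu_{ij}-\mu)(\mu_{ij}-\mu)^{\top}
$$

By the law of total covariance, the first term of $\Sigma$ averages the emulators' own uncertainty and the second measures how much their means disagree. When every emulator returns a zero covariance, only the disagreement term remains.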
AutoEmulate/Ensemble
Both objects share functionality, e.g. training/fitting emulators. Perhaps we should break up some of `AutoEmulate`'s methods to:
- maximise the number of methods shared with `Ensemble` through inheritance, e.g. by pulling training/fitting out of `AutoEmulate.compare` and sharing `AutoEmulate.fit(X, Y)` and `Ensemble.fit(X, Y)`;
- minimise the number of new methods specific to `AutoEmulate`.
Pseudo-code
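from dataclasses import dataclass
from typing import List, Tuple

import torch


class Emulator:
    def fit(self, X, Y):
        raise NotImplementedError

    def predict(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        # returns (mean, covariance); repeated calls may differ
        # and the covariance may be zero
        raise NotImplementedError


@dataclass
class Ensemble(Emulator):
    # these emulators could be ensembles
    emulators: List[Emulator]
    # number of samples per emulator
    # maybe this should belong to the emulator?
    n_samples: List[int]

    def predict(self, x):
        # compute ensemble mean and covariance from all emulator calls
        # assumption: every call is weighted equally (law of total covariance)
        # and each call returns a mean of shape (d,) and a covariance of shape (d, d)
        means, covs = [], []
        for emulator, n in zip(self.emulators, self.n_samples):
            for _ in range(n):
                mean_i, cov_i = emulator.predict(x)
                means.append(mean_i)
                covs.append(cov_i)
        means, covs = torch.stack(means), torch.stack(covs)
        mean = means.mean(dim=0)
        diff = means - mean
        covariance = covs.mean(dim=0) + diff.T @ diff / means.shape[0]
        return mean, covariance

    def score(self, x):
        # disagreement score for query-by-committee
        # possibly should be in the active learning method instead
        _, covariance = self.predict(x)
        return torch.trace(covariance)

    def fit(self, X, Y):
        # fit all emulators; the fits are independent, so this loop could run in parallel
        for emulator in self.emulators:
            emulator.fit(X, Y)


class AutoEmulate(Ensemble):
    # inherits attributes/methods from Ensemble
    def compare(self, X, Y):
        # use ensemble method
        self.fit(X, Y)
        # then do cross-validation, etc.
        ...

    # NOTE: some ensemble methods might not be used
    def other_autoemulate_methods(self):
        ...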
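The aggregation step in `Ensemble.predict` and the trace-based `score` can be exercised in isolation. The sketch below assumes that every call is weighted equally and that each call returns a mean of shape `(d,)` and a covariance of shape `(d, d)`. The helper names `aggregate_moments` and `disagreement_score` are hypothetical and do not come from the AutoEmulate codebase.

```python
from typing import List, Tuple

import torch


def aggregate_moments(
    means: List[torch.Tensor], covs: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine K (mean, covariance) pairs into one mean and covariance.

    Hypothetical helper: assumes an equally weighted mixture of all calls.
    """
    m = torch.stack(means)  # shape (K, d)
    c = torch.stack(covs)  # shape (K, d, d)
    mean = m.mean(dim=0)
    diff = m - mean
    # average within-call covariance + covariance of the call means
    covariance = c.mean(dim=0) + diff.T @ diff / m.shape[0]
    return mean, covariance


def disagreement_score(covariance: torch.Tensor) -> torch.Tensor:
    # query-by-committee style score: total variance summed over output dimensions
    return torch.trace(covariance)


# two hypothetical deterministic emulators (zero covariance) that disagree on dim 0
example_means = [torch.tensor([0.0, 1.0]), torch.tensor([2.0, 1.0])]
example_covs = [torch.zeros(2, 2), torch.zeros(2, 2)]
example_mean, example_cov = aggregate_moments(example_means, example_covs)
# mean is [1, 1], covariance is [[1, 0], [0, 0]], score is 1
print(example_mean, example_cov, disagreement_score(example_cov))
```

With zero covariances, the whole score comes from the emulators' disagreement, which is the quantity a query-by-committee strategy is meant to target.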
Parallelization of emulator training relates to #421.
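A minimal sketch of fitting independent emulators concurrently with the standard library's `concurrent.futures`. It is not necessarily how #421 will approach this. The helper `fit_all_parallel` and the toy `MeanEmulator` are hypothetical. Threads keep the sketch simple and help when the fitting code releases the GIL, as many PyTorch and NumPy operations do. Pure-Python fitting would likely need `ProcessPoolExecutor` instead.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Protocol


class Fittable(Protocol):
    def fit(self, X: Any, Y: Any) -> None: ...


def fit_all_parallel(emulators: List[Fittable], X: Any, Y: Any, max_workers: int = 4) -> None:
    """Hypothetical helper: fit each emulator on the same data in a thread pool."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(emulator.fit, X, Y) for emulator in emulators]
        for future in futures:
            # result() waits for the fit to finish and re-raises any worker exception
            future.result()


class MeanEmulator:
    """Hypothetical toy emulator that memorises the mean of Y."""

    def __init__(self) -> None:
        self.mean = None

    def fit(self, X: Any, Y: List[float]) -> None:
        self.mean = sum(Y) / len(Y)


emulators = [MeanEmulator() for _ in range(3)]
fit_all_parallel(emulators, X=[0.0, 1.0, 2.0], Y=[1.0, 2.0, 3.0])
```

Because `Ensemble.fit` only needs each member's `fit`, a helper like this could sit behind it without changing the inheritance design proposed above.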