radka-j
Member

@radka-j radka-j commented Oct 6, 2025

Closes #748
Closes #757
Closes #874

Overview:

  • emulators save all their input args so that all input values can be retrieved
  • replace any **kwargs with optional keyword arguments to match usage (they are mostly used to pass scheduler kwargs)
  • update HMW so that the user can pass an emulator as well as a result
  • re-initialize emulators when refitting in AL
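
The first two points and the last one can be illustrated with a minimal sketch. Everything here is hypothetical: `SketchEmulator`, `get_init_args`, and the `scheduler_cls` / `scheduler_params` keywords are names made up for illustration, not the actual AutoEmulate API. The idea is only that explicit optional keywords replace a catch-all `**kwargs`, every constructor argument is recorded, and a fresh, unfitted copy can be rebuilt from the recorded values.

```python
class SketchEmulator:
    """Hypothetical emulator that records its own constructor arguments."""

    def __init__(self, lr=0.01, epochs=100, scheduler_cls=None, scheduler_params=None):
        # Explicit optional keywords instead of **kwargs: every accepted
        # argument is visible in the signature.
        self.lr = lr
        self.epochs = epochs
        self.scheduler_cls = scheduler_cls
        self.scheduler_params = dict(scheduler_params or {})
        self.is_fitted = False
        # Save all input values so they can be retrieved later.
        self._init_args = {
            "lr": lr,
            "epochs": epochs,
            "scheduler_cls": scheduler_cls,
            "scheduler_params": dict(self.scheduler_params),
        }

    def get_init_args(self):
        # Return a copy so callers cannot mutate the stored values.
        return {
            key: (dict(value) if isinstance(value, dict) else value)
            for key, value in self._init_args.items()
        }

    def fit(self, x, y):
        # Placeholder for real training; only flips the fitted flag.
        self.is_fitted = True
        return self


# Refit from scratch: rebuild an unfitted emulator from the saved arguments.
fitted = SketchEmulator(lr=0.1, scheduler_params={"gamma": 0.9}).fit([[0.0]], [0.0])
fresh = SketchEmulator(**fitted.get_init_args())
```

Under these assumptions, `fresh` carries the same hyperparameters as `fitted` but none of its trained state.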

TODOs:

  • add tests
  • should the reinitialisation code be a method to avoid duplication of code?
  • check attribute access error issue in HMW (lines 383 and 388)

Base automatically changed from iss867/update_gp_factory to main October 6, 2025 13:19
@sgreenbury
Collaborator

Just adding a note here as I ran into this when working with a GP subclass for the error quantification. This call:

model_class = get_emulator_class(result.model_name)

fails since:
emulator_cls = EMULATOR_REGISTRY.get(
    name.lower()
) or EMULATOR_REGISTRY_SHORT_NAME.get(name.lower())

doesn't also look at:
GP_REGISTRY = {
    "GaussianProcess": GaussianProcess,
    "GaussianProcessCorrelated": GaussianProcessCorrelated,
}

@radka-j - adding this here as it might be addressed by the upcoming changes to this API. If not, I'm happy to open a new issue to look at this. Another option could be to revisit having a central registry class to handle this uniformly.
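
To make the "central registry class" option concrete, here is a rough sketch under stated assumptions. `EmulatorRegistry`, its `register` / `get` methods, and `GaussianProcessSketch` are hypothetical names for illustration, not existing AutoEmulate code. A single object holds full names and short names, so a lookup does not have to know about several separate dictionaries.

```python
class EmulatorRegistry:
    """Hypothetical single registry, keyed by lower-cased name."""

    def __init__(self):
        self._classes = {}

    def register(self, cls, short_name=None):
        # Register under the class name and, optionally, a short alias.
        self._classes[cls.__name__.lower()] = cls
        if short_name is not None:
            self._classes[short_name.lower()] = cls
        return cls

    def get(self, name):
        # Case-insensitive lookup with an explicit error for unknown names.
        try:
            return self._classes[name.lower()]
        except KeyError:
            raise ValueError(f"Unknown emulator: {name!r}") from None


registry = EmulatorRegistry()


class GaussianProcessSketch:
    """Stand-in class; a real emulator would go here."""


registry.register(GaussianProcessSketch, short_name="gp")

# Both the full name and the short name resolve to the same class.
found_by_name = registry.get("GaussianProcessSketch")
found_by_alias = registry.get("GP")
```

Subclasses created at runtime would still need to call `register` themselves (or be registered by whatever factory creates them) for this lookup to find them.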

@radka-j
Member Author

radka-j commented Oct 6, 2025

@sgreenbury I don't think we should ever use the GaussianProcess or GaussianProcessCorrelated classes, so this feels like correct behaviour to me. If we want a GP class for an RBF + constant kernel, we should add that specifically to the registry.

@sgreenbury
Collaborator

It was in the GP context (passing a create_gp_subclass instance to AutoEmulate) that I ran into this issue, and a workaround might have been to also look at GP_REGISTRY, since this maintains a registry of all GPs including the created subclasses.

But thinking more about it, this currently affects any subclass used by AutoEmulate if reinitialize is called, e.g. in the advanced tutorial:

class SimpleFNN(PyTorchBackend):
    ...
ae = AutoEmulate(x, y, models=[SimpleFNN])
ae.fit_from_reinitialized(x, y)

since SimpleFNN is constructed at runtime, the class is not found in the lists of emulators.

I think if the emulator becomes the entity that does the refitting in this PR, then a global emulator registry including all custom subclasses would not be needed for this, but might still be useful?
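
A small sketch of why refitting on the emulator side would sidestep the lookup problem. `BaseSketch`, `RuntimeSubclass`, `NAME_REGISTRY`, and this `reinitialize` method are all hypothetical illustrations, not the actual AutoEmulate classes. The instance rebuilds itself through `type(self)`, which resolves to the concrete subclass even when that subclass was defined at runtime and never registered by name.

```python
class BaseSketch:
    """Hypothetical base emulator that can rebuild itself."""

    def __init__(self, hidden=8):
        self.hidden = hidden
        self._init_args = {"hidden": hidden}
        self.is_fitted = False

    def fit(self, x, y):
        # Placeholder for real training; only flips the fitted flag.
        self.is_fitted = True
        return self

    def reinitialize(self):
        # type(self) is the concrete class of this instance, so no
        # name-based registry lookup is needed.
        return type(self)(**self._init_args)


# A name-based registry only knows the classes registered ahead of time.
NAME_REGISTRY = {"basesketch": BaseSketch}


class RuntimeSubclass(BaseSketch):
    """User-defined subclass, created after the registry was built."""


model = RuntimeSubclass(hidden=16).fit([[0.0]], [0.0])

# The name lookup misses the runtime subclass...
looked_up = NAME_REGISTRY.get(type(model).__name__.lower())
# ...while the instance can still produce a fresh copy of itself.
fresh = model.reinitialize()
```

One caveat with this sketch: a subclass that adds its own constructor arguments would also have to record them, otherwise `reinitialize` would silently drop them.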

Development

Successfully merging this pull request may close these issues.

  • Revisit API for scheduler params
  • Re-initialize emulator when refitting in AL
  • Use fit_from_initialised when refitting emulators in HMW