[Feature Request] Partial instantiate/call #1283
Comments
Can you explain what problem you are trying to solve in more detail?
Sure, below is an example similar to what I am facing in my current project. Consider this simple CNN class:

```python
# main.py
from torch import nn
from typing import Sequence


class CNN(nn.Module):
    def __init__(self,
                 input_channels: int,
                 hidden_sizes: Sequence[int],
                 conv_layer=nn.Conv2d,
                 norm_layer=nn.BatchNorm2d,
                 activation_fn=nn.ReLU):
        super().__init__()
        layers = []
        for h_size in hidden_sizes:
            layers.append(conv_layer(input_channels, h_size, kernel_size=(3, 3)))
            layers.append(norm_layer(h_size))
            layers.append(activation_fn())
            input_channels = h_size
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x).mean((-1, -2))
```

I would like to be able to change what type of norm_layer (and its arguments) I use from my config. With partial instantiation, I can specify my config as follows:
```yaml
# config.yaml
model:
  _target_: main.CNN
  input_channels: 3
  hidden_sizes: [8, 16, 32, 10]
  norm_layer:
    _target_: torch.nn.SyncBatchNorm
    _partial_: True
    momentum: 0.2
    eps: 0.01
```

Then in my main:
```python
# main.py
from hydra.utils import instantiate

model = instantiate(cfg.model)
```

which would be the equivalent of:
```python
# main.py
from functools import partial

import torch

model = CNN(
    input_channels=3,
    hidden_sizes=[8, 16, 32, 10],
    norm_layer=partial(torch.nn.SyncBatchNorm, momentum=0.2, eps=0.01),
)
```

Does this make sense? Is there any way to accomplish this currently?
Yes, it takes some acrobatics, but you can already achieve it in a relatively clean way:

```python
from typing import Any
from functools import partial

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Registering a resolver that can return a callable method.
# This also works on classes, but you can also register get_class, which is almost identical.
OmegaConf.register_resolver("get_method", hydra.utils.get_method)

cfg = OmegaConf.create({"method": "${get_method:math.sin}"})
print("sin(1)", cfg.method(1))


class Foo:
    def __init__(self) -> None:
        self.a = 1
        self.b = 1

    def run(self) -> int:
        return self.a + self.b


class UsefulFoo(Foo):
    def __init__(self, a: int, b: int) -> None:
        self.a = a
        self.b = b


prt = partial(UsefulFoo, 10)
print("20 + 10 = ", prt(20).run())


class Bar:
    def __init__(self, foo: Any = Foo, y: int = 1) -> None:
        self.result = foo().run()


def partial2(func: Any, *args, **kwargs) -> Any:
    """
    Normal partial requires func to be passed as a positional argument.
    This is not currently supported by instantiate; this function bridges that gap.
    """
    return partial(func, *args, **kwargs)


print("direct instantiate", Bar(y=2).result)
print(
    "instantiate",
    instantiate({"_target_": "__main__.Bar"}).result,
)
print(
    "instantiate partial",
    instantiate(
        {
            "_target_": "__main__.Bar",
            "foo": {
                "_target_": "__main__.partial2",
                "func": "${get_method:__main__.UsefulFoo}",
                "a": 10,
                "b": 20,
            },
        }
    ).result,
)
```

Output:
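(Reconstructed by running the snippet above rather than copied from the original comment; exact float formatting may vary.)

```text
sin(1) 0.8414709848078965
20 + 10 =  30
direct instantiate 2
instantiate 2
instantiate partial 30
```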
Thanks, this solves my problem! I think it would be clean to have this integrated into the library via `_partial_`.
Great! If this issue generates a lot of interest I can consider adding explicit support later. I hope my answer can help others in a similar situation. Closing for now.
Will re-evaluate for 1.2.
FYI: Hydra 1.1 instantiate now supports positional arguments.

```python
from functools import partial

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Plain Python version
basetwo = partial(int, base=2)
assert basetwo("10010") == 18

# instantiate version.
# Registering a resolver that can return a callable method.
# This also works on classes, but you can also register get_class, which is almost identical.
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)

basetwo2 = instantiate(
    {
        "_target_": "functools.partial",
        "_args_": ["${get_method:builtins.int}"],
        "base": 2,
    }
)
assert basetwo2("10010") == 18
```
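Applied to the CNN example from earlier in this thread, the same trick would look roughly like the sketch below. It assumes the `get_method` resolver registered above, Hydra 1.1 with recursive instantiation, and the `main.CNN` class defined in the earlier comment; it is a workaround pattern, not an official Hydra API for partials.

```python
from hydra.utils import instantiate

# The nested functools.partial target builds a partially-applied norm layer,
# which CNN then calls as norm_layer(h_size) for each hidden layer.
cfg = {
    "_target_": "main.CNN",
    "input_channels": 3,
    "hidden_sizes": [8, 16, 32, 10],
    "norm_layer": {
        "_target_": "functools.partial",
        "_args_": ["${get_method:torch.nn.SyncBatchNorm}"],
        "momentum": 0.2,
        "eps": 0.01,
    },
}
model = instantiate(cfg)
```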
You can make this very clean with a wrapper around
I tried a bunch of different ways of doing this with interpolations and more tricks with the base config to try to make it even cleaner, but this was the only thing I could get working.
FWIW, we'd love if this made it into the library. Using @Queuecumber's solution at the moment, and understand if you don't end up including explicit partials (it's a little complicated), but just wanted to mention that we'd love to see them too and are happy to answer any questions about our use case.
Thanks for the excellent example @omry. Here's a PyTorch Lightning-specific demo I created to illustrate @xvr-hlt's use case of passing partial optimizers through for later use with model parameters:

```python
from functools import partial

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch


class BoringModel(pl.LightningModule):
    def __init__(self, optim_partial, in_feats=4, out_feats=2):
        super().__init__()
        # Hydra config components
        self.optim_partial = optim_partial
        self.in_feats = in_feats
        self.out_feats = out_feats
        # Control weight randomness
        pl.seed_everything(1234)
        self.layer = torch.nn.Linear(in_feats, out_feats)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return self.optim_partial(self.parameters())


# Plain Python approach
model = BoringModel(optim_partial=partial(torch.optim.Adam, lr=1e-5, weight_decay=0.2))
optimizer = model.configure_optimizers()

# Partial instantiate approach
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
cfg = {
    '_target_': '__main__.BoringModel',
    'in_feats': 4,
    'out_feats': 2,
    'optim_partial': {
        '_target_': 'functools.partial',
        '_args_': ['${get_method: torch.optim.Adam}'],
        'lr': 1e-5,
        'weight_decay': 0.2,
    },
}
model2 = instantiate(cfg)
optimizer2 = model2.configure_optimizers()

# Equality comparison of all optimization hyperparameters + model parameters
for g, group in enumerate(optimizer2.param_groups):
    for k, v in group.items():
        if k == 'params':
            for p, param in enumerate(v):
                assert torch.all(param == optimizer.param_groups[g][k][p])
        else:
            assert v == optimizer.param_groups[g][k]
```
First off, thanks for the wonderful library!!
🚀 Feature Request
I would like to be able to partially instantiate/call a class/function. For example:
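The example block from the original issue body is not reproduced here; a sketch of the proposed usage, mirroring the config shown elsewhere in this thread (the `_partial_` flag is hypothetical syntax at the time this issue was filed), might look like this:

```python
from hydra.utils import instantiate

# Proposed behavior: with _partial_: True, instantiate would return
# functools.partial(torch.nn.SyncBatchNorm, momentum=0.2, eps=0.01)
# instead of constructing the module immediately.
norm_layer = instantiate({
    "_target_": "torch.nn.SyncBatchNorm",
    "_partial_": True,
    "momentum": 0.2,
    "eps": 0.01,
})
layer = norm_layer(64)  # the partial is called later with the remaining arguments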
Motivation
Currently, to make the above use-case work, I would do the following:
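A representative sketch of this kind of workaround, a small wrapper class that stores the configured keyword arguments and builds the layer when called (the class name and parameters here are hypothetical, not taken from the original issue):

```python
import torch


class NormLayerFactory:
    """Hypothetical wrapper: holds the configured kwargs so a plain instantiate can
    build it, and produces the actual layer when called with num_features."""

    def __init__(self, momentum: float = 0.1, eps: float = 1e-5):
        self.momentum = momentum
        self.eps = eps

    def __call__(self, num_features: int) -> torch.nn.Module:
        return torch.nn.BatchNorm2d(num_features, momentum=self.momentum, eps=self.eps)


# In the config, norm_layer would then point at this wrapper:
# norm_layer: {_target_: main.NormLayerFactory, momentum: 0.2, eps: 0.01}
```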
This is acceptable for code I write, but does not work well when I am trying to configure functions that libraries provide. Many libraries follow a more functional style (compared to PyTorch and TensorFlow), so losses/activations/metrics are provided as simple functions as opposed to callable objects. For example, Flax for JAX (and several other neural network libraries for JAX) defines all its activation functions and pooling layers as straightforward functions instead of classes, making partial instantiation crucial for configuration.
Also, code will often be more clear when there are more simple functions and fewer higher order functions/classes. Partial instantiation will prevent code from having too many of the latter.
Pitch
Describe the solution you'd like
Having an optional `_partial_` entry in the config (similar to `_recursive_` and `_convert_`) is, in my view, the most straightforward way to achieve this. By default this would be `False`; when `True`, it would partially instantiate/call the class/function instead of actually instantiating/calling it.
Describe alternatives you've considered
Another option is to introduce two new methods: `hydra.utils.partial_instantiate` and `hydra.utils.partial_call`. This removes the need for another config entry and makes it clearer at the call site what's going on. There is one major disadvantage: it's not clear how this would work with `_recursive_=True`. Would all the recursive instantiations be partial? You probably don't want that. Would only the top-level instantiation be partial? This would limit some use cases as well. For this reason, I think the `_partial_` entry makes the most sense.
Are you willing to open a pull request? (See CONTRIBUTING)
Yes! I was planning to make a `_pop_is_partial` function (like `pop_convert_mode`), then add the appropriate `functools.partial` calls here.