# Binding a module

ArgBind allows you to bind an entire module, giving back an object that behaves like the original module but has all of its functions and classes bound to the parser. This is useful when you're working with a big package and don't want to sprinkle decorators throughout the package code for every function you care about. Instead, you can bind everything at once and get back a package with everything bound as needed.

Let's start by using ArgBind to bind all of the optimizers in `torch.optim`:

```python
import torch
import argbind

# Bind every class in torch.optim that has a `step` attribute (the optimizers).
optim = argbind.bind_module(
    torch.optim,
    filter_fn=lambda fn: hasattr(fn, "step")
)
args = {
    "lr": 2e-4,
    "args.debug": True,  # print what ArgBind binds at call time
}

net = torch.nn.Linear(1, 1)
for fn_name in dir(optim):
    if fn_name.startswith("_") or fn_name == "Optimizer":
        continue
    fn = getattr(optim, fn_name)
    # Route the shared learning rate to this optimizer's scoped `lr` argument.
    args[f"{fn_name}.lr"] = args["lr"]
    with argbind.scope(args):
        fn(net.parameters())
```

A few things to note: `bind_module` optionally takes a `filter_fn`, which receives each object in the module and returns a bool indicating whether that object should be bound. You can use this to filter out any functions or classes that you don't want bound. Here we use it to bind only the classes that have a `step` attribute (these are the optimizers).
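For illustration, a `filter_fn` can also select by name. This sketch is not from the example above; it assumes only what was just described (that `filter_fn` receives each candidate object), and the `WANTED` whitelist is hypothetical:

```python
import torch
import argbind

# Hypothetical whitelist: bind only the optimizers named here.
WANTED = {"Adam", "SGD"}
optim = argbind.bind_module(
    torch.optim,
    filter_fn=lambda obj: getattr(obj, "__name__", "") in WANTED,
)
```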

Then we just write our script normally, like any other ArgBind script.
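For instance, a minimal command-line driver might look like the sketch below. It uses ArgBind's usual entry point (`argbind.parse_args` plus `argbind.scope`); the script name and the choice of `Adam` are illustrative:

```python
import torch
import argbind

optim = argbind.bind_module(
    torch.optim,
    filter_fn=lambda fn: hasattr(fn, "step")
)

if __name__ == "__main__":
    # e.g. python train.py --Adam.lr 1e-3
    args = argbind.parse_args()
    with argbind.scope(args):
        net = torch.nn.Linear(1, 1)
        optimizer = optim.Adam(net.parameters())
```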

N.B.: `bind_module` only goes ONE level deep. It does not recursively apply itself to bind submodules. This could happen in the future, but the logic must be done carefully to avoid loops caused by circular imports.
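In the meantime, you can bind a submodule explicitly by calling `bind_module` on it yourself. A minimal sketch, assuming the scheduler classes in `torch.optim.lr_scheduler` expose a `step` method (which they do in PyTorch):

```python
import torch
import argbind

# bind_module is not recursive, so bind the submodule directly.
lr_scheduler = argbind.bind_module(
    torch.optim.lr_scheduler,
    filter_fn=lambda obj: hasattr(obj, "step"),
)
```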

For comparison, here is how the code would look without binding the entire module, while achieving the same output:

```python
import torch
import argbind

# Binding all the optimizers, one by one.
ASGD = argbind.bind(torch.optim.ASGD)
Adadelta = argbind.bind(torch.optim.Adadelta)
Adagrad = argbind.bind(torch.optim.Adagrad)
Adam = argbind.bind(torch.optim.Adam)
AdamW = argbind.bind(torch.optim.AdamW)
Adamax = argbind.bind(torch.optim.Adamax)
LBFGS = argbind.bind(torch.optim.LBFGS)
NAdam = argbind.bind(torch.optim.NAdam)
RAdam = argbind.bind(torch.optim.RAdam)
RMSprop = argbind.bind(torch.optim.RMSprop)
Rprop = argbind.bind(torch.optim.Rprop)
SGD = argbind.bind(torch.optim.SGD)
SparseAdam = argbind.bind(torch.optim.SparseAdam)

optimizers = [
    ASGD, Adadelta, Adagrad, Adam, AdamW, Adamax, LBFGS,
    NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
]

# Collect the bound optimizers on a single object, mimicking bind_module's output.
class holder:
    def __init__(self):
        for o in optimizers:
            setattr(self, o.__name__, o)

optim = holder()

args = {
    "lr": 2e-4,
    "args.debug": True,
}

net = torch.nn.Linear(1, 1)

for fn_name in dir(optim):
    if fn_name == "Optimizer":
        continue
    fn = getattr(optim, fn_name)
    # Only the optimizer classes have a `step` attribute; skip everything else.
    if hasattr(fn, "step"):
        args[f"{fn_name}.lr"] = args["lr"]
        with argbind.scope(args):
            fn(net.parameters())
```

In this script, every single optimizer must be bound one by one, which is rather tedious. `bind_module` is more succinct.