PyTorch does a great job of constructing an object-oriented interface for neural networks. Methods like
net.cuda()
net.cpu()
net.half()
net.float()
net.to(device, dtype)
allow the user to think of a network as a single object which can be altered, rather than as a series of smaller objects that must be moved individually.
However, a similar interface does not exist for datatypes. We often want to keep track of multiple related pieces of information about a single example, such as the class of an image, its bounding boxes, or the raw string of a text paired with its encoded, tokenized int64 indices.
There are also a number of operations that must take into account all subcomponents, including:
- Moving to/from devices
- Changing dtype
- Slicing
- Collating
A common approach is to treat these data objects as dictionaries and use some Python treemap function to apply an operation to every subcomponent. However, this breaks multiple coding principles. Dictionaries are untyped, which invites errors whenever it is unclear what a dictionary contains. Some operations, such as slicing, are complicated and may require different indexing for different components. More broadly, this breaks the object-oriented approach of other PyTorch components.
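For concreteness, here is a minimal sketch of that dictionary approach; the tree_map helper is hypothetical, written only to illustrate the failure modes:

import torch

example = {
    "id": "test",
    "sequence": "ABCDE",
    "tensor": torch.arange(5),
}

def tree_map(fn, tree):
    # Apply fn to every leaf of a (possibly nested) dict.
    return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v)
            for k, v in tree.items()}

# Every operation needs per-leaf special-casing, because the dict carries
# no type information about its values:
floated = tree_map(lambda v: v.float() if isinstance(v, torch.Tensor) else v,
                   example)

# And a naive slice silently corrupts fields that should not be sliced:
sliced = tree_map(lambda v: v[:2], example)
assert sliced["id"] == "te"  # the identifier was truncated along with the data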
We want a method that supports all of:
- Moving to/from devices
- Changing dtype
- Slicing
- Collating
in an intuitive, user-friendly manner. Additionally, like PyTorch modules, datatypes should be composable.
This module defines the Datablock class, which can turn any frozen dataclass into a composable, torch-friendly datablock. Example:
import torch
from dataclasses import dataclass, field

from datablocks import Datablock

@dataclass(frozen=True, repr=False)
class Foo(Datablock):
    id: str
    sequence: str = field(metadata={"dim": -1})  # the dimension along which slicing should occur
    tensor: torch.Tensor = field(metadata={"pad": -1, "dim": -1})  # the slicing dimension plus the pad value used when collating

    @classmethod
    def from_sequence(cls, id: str, sequence: str):
        tensor = torch.tensor(encode(sequence), dtype=torch.int64)
        return cls(id=id, sequence=sequence, tensor=tensor)
Now we can create a Foo and operate on it in various ways:
>>> import string
>>> encode = lambda x: [string.ascii_uppercase.index(tok) for tok in x]
>>> header = "test"
>>> seq = "ABCDE"
>>> foo = Foo.from_sequence(header, seq)
>>> foo
Foo(
id=test,
sequence=ABCDE,
tensor=tensor([0, 1, 2, 3, 4]),
batch_size=0,
dtype=torch.float32,
device=cpu,
)
>>> foo[:2]
Foo(
id=test,
sequence=AB,
tensor=tensor([0, 1]),
batch_size=0,
dtype=torch.float32,
device=cpu,
)
# Lists will work, along with numpy arrays, torch tensors, negative indexing, etc.
>>> foo[[0, 1, 3]]
Foo(
id=test,
sequence=ABD,
tensor=tensor([0, 1, 3]),
batch_size=0,
dtype=torch.float32,
device=cpu,
)
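For instance, building on the claim above, a sketch of two more index types (only the sequence field is shown):
>>> foo[-2:].sequence
"DE"
>>> foo[torch.tensor([0, 2])].sequence
"AC"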
>>> foo.cuda()
Foo(
id=test,
sequence=ABCDE,
tensor=tensor([0, 1, 2, 3, 4], device='cuda:0'),
batch_size=0,
dtype=torch.float32,
device=cuda,
)
# Moving to/from cuda and slicing will still work after collating
>>> Foo.collate([foo[:2], foo[2:]])
Foo(
id=["test", "test"],
sequence=["AB", "CDE"],
tensor=tensor(
    [[ 0,  1, -1],  # the pad value is -1 because of the metadata given in the declaration
     [ 2,  3,  4]]
),
batch_size=2,
dtype=torch.float32,
device=cpu,
)
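Because collate is an ordinary classmethod mapping a list of blocks to a single block, it can be handed straight to a torch DataLoader as its collate_fn. A minimal sketch (the dataset here is just a list, which satisfies the map-style dataset protocol):

from torch.utils.data import DataLoader

dataset = [Foo.from_sequence("test", s) for s in ["AB", "CDE", "FGHI"]]
loader = DataLoader(dataset, batch_size=2, collate_fn=Foo.collate)

for batch in loader:
    batch = batch.cuda()  # the whole Foo batch moves together (assuming a GPU is available)
    ...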
It's also possible to compose objects in a straightforward manner:
@dataclass(frozen=True, repr=False)
class Bar(Datablock):
    foo: Foo
    baz: torch.Tensor = field(metadata={"dim": -1})
>>> bar = Bar(foo=foo, baz=torch.arange(5))
# All methods (slicing, moving to/from devices, changing dtype) will still work.
>>> bar[:2]
Bar(
foo=Foo(
id=test,
sequence=AB,
tensor=tensor([0, 1]),
batch_size=0,
dtype=torch.float32,
device=cpu,
),
baz=tensor([0, 1]),
batch_size=0,
dtype=torch.float32,
device=cpu,
)
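The remaining operations should compose the same way; a sketch, assuming collate recurses into nested Datablocks just as slicing does:

# Collating pads bar.foo.tensor exactly as in the Foo example above.
batch = Bar.collate([bar[:2], bar[2:]])

# Device moves propagate through the composition in a single call
# (assuming a CUDA device is available):
batch = batch.cuda()
assert batch.foo.tensor.device == batch.baz.device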
To define a lazy property, use the built-in functools.cached_property decorator:
import functools
import time
from dataclasses import dataclass, field

import torch

from datablocks import Datablock

@dataclass(frozen=True, repr=False)
class Lazy(Datablock):
    # default_factory avoids sharing one tensor across all instances
    a: torch.Tensor = field(default_factory=lambda: torch.randn(3, 5))

    @functools.cached_property
    def hello(self) -> str:
        time.sleep(10)  # mimic a long computation
        return "world"
Now, accessing this property multiple times will show the speedup:
>>> lazy = Lazy()
>>> lazy.hello
"world" # 10 s
>>> lazy.hello
"world" # <0.1 µs
>>> lazy = lazy.cuda()
>>> lazy.hello
"world" # 10 s, shift to cuda removes cache