import re
from typing import Optional, List
import torch
from torch.autograd import grad
from torch import nn
from torch_scatter import scatter
from pytorch_lightning.utilities import rank_zero_warn
from torchmdnet.models import output_modules
from torchmdnet.models.wrappers import AtomFilter
from torchmdnet import priors


def create_model(args, prior_model=None, mean=None, std=None):
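    """Create a TorchMD_Net model from a dictionary of hyperparameters.

    Args:
        args (dict): hyperparameters selecting the representation network
            ("graph-network", "transformer" or "equivariant-transformer"),
            the output module, an optional prior model and whether the
            negative derivative with respect to the coordinates is returned.
        prior_model (nn.Module, optional): prior model to use instead of
            instantiating one from ``args["prior_model"]``.
        mean (torch.Tensor, optional): mean added to the aggregated prediction.
        std (torch.Tensor, optional): standard deviation used to scale the
            per-atom output before aggregation.

    Returns:
        TorchMD_Net: the assembled model.
    """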
shared_args = dict(
hidden_channels=args["embedding_dimension"],
num_layers=args["num_layers"],
num_rbf=args["num_rbf"],
rbf_type=args["rbf_type"],
trainable_rbf=args["trainable_rbf"],
activation=args["activation"],
neighbor_embedding=args["neighbor_embedding"],
cutoff_lower=args["cutoff_lower"],
cutoff_upper=args["cutoff_upper"],
max_z=args["max_z"],
max_num_neighbors=args["max_num_neighbors"],
)
# representation network
if args["model"] == "graph-network":
from torchmdnet.models.torchmd_gn import TorchMD_GN
# TODO: remove legacy
args["aggr"] = args["aggr"] if "aggr" in args else "add"
is_equivariant = False
representation_model = TorchMD_GN(
num_filters=args["embedding_dimension"], aggr=args["aggr"], **shared_args
)
elif args["model"] == "transformer":
from torchmdnet.models.torchmd_t import TorchMD_T
is_equivariant = False
representation_model = TorchMD_T(
attn_activation=args["attn_activation"],
num_heads=args["num_heads"],
distance_influence=args["distance_influence"],
**shared_args,
)
elif args["model"] == "equivariant-transformer":
from torchmdnet.models.torchmd_et import TorchMD_ET
is_equivariant = True
representation_model = TorchMD_ET(
attn_activation=args["attn_activation"],
num_heads=args["num_heads"],
distance_influence=args["distance_influence"],
**shared_args,
)
else:
raise ValueError(f'Unknown architecture: {args["model"]}')
# atom filter
if not args["derivative"] and args["atom_filter"] > -1:
representation_model = AtomFilter(representation_model, args["atom_filter"])
elif args["atom_filter"] > -1:
raise ValueError("Derivative and atom filter can't be used together")
# prior model
if args["prior_model"] and prior_model is None:
assert "prior_args" in args, (
f"Requested prior model {args['prior_model']} but the "
f'arguments are lacking the key "prior_args".'
)
assert hasattr(priors, args["prior_model"]), (
f'Unknown prior model {args["prior_model"]}. '
f'Available models are {", ".join(priors.__all__)}'
)
# instantiate prior model if it was not passed to create_model (i.e. when loading a model)
prior_model = getattr(priors, args["prior_model"])(**args["prior_args"])
# create output network
output_prefix = "Equivariant" if is_equivariant else ""
output_model = getattr(output_modules, output_prefix + args["output_model"])(
args["embedding_dimension"], args["activation"]
)
# combine representation and output network
model = TorchMD_Net(
representation_model,
output_model,
prior_model=prior_model,
reduce_op=args["reduce_op"],
mean=mean,
std=std,
derivative=args["derivative"],
)
    return model


def load_model(filepath, args=None, device="cpu", **kwargs):
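    """Load a TorchMD_Net model from a checkpoint file.

    Args:
        filepath (str): path to a PyTorch Lightning checkpoint.
        args (dict, optional): hyperparameters to use instead of the ones
            stored in the checkpoint.
        device (str or torch.device, optional): device the model is moved to.
        **kwargs: overrides for individual hyperparameters; each key must
            already exist in the checkpoint's hyperparameters.

    Returns:
        TorchMD_Net: the model with its weights restored from the checkpoint.
    """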
ckpt = torch.load(filepath, map_location="cpu")
if args is None:
args = ckpt["hyper_parameters"]
for key, value in kwargs.items():
        assert key in args, f"Unknown hyperparameter '{key}'."
args[key] = value
model = create_model(args)
# TODO: remove `("" if k.startswith("model.network.") else "network.") + `
# in the future. This is legacy code for loading old checkpoint files.
state_dict = {
("" if k.startswith("model.network.") else "network.")
+ re.sub(r"^model\.", "", k): v
for k, v in ckpt["state_dict"].items()
}
model.load_state_dict(state_dict)
    return model.to(device)


class TorchMD_Net(nn.Module):
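    """Combines a representation model, an output model and an optional prior
    model into a single module and, if ``derivative=True``, additionally
    returns the negative gradient of the prediction with respect to the atom
    coordinates.
    """
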
def __init__(
self,
representation_model,
output_model,
prior_model=None,
reduce_op="add",
mean=None,
std=None,
derivative=False,
):
super(TorchMD_Net, self).__init__()
# TODO: merge content from PredictionNetwork into TorchMD_Net
# once https://github.com/pytorch/pytorch/issues/63145 is fixed
self.network = PredictionNetwork(
representation_model, output_model, prior_model, reduce_op, mean, std
)
        self.derivative = derivative

    def reset_parameters(self):
        self.network.reset_parameters()

    def forward(self, z, pos, batch: Optional[torch.Tensor] = None):
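        """Run the model on a (batched) set of atoms.

        Args:
            z (torch.Tensor): atomic numbers as a one-dimensional ``long`` tensor.
            pos (torch.Tensor): atom coordinates.
            batch (torch.Tensor, optional): index assigning each atom to a
                sample; defaults to a single sample containing all atoms.

        Returns:
            The reduced prediction and, if ``derivative=True``, its negative
            gradient with respect to ``pos``; otherwise an empty tensor, since
            torch.jit.trace expects all return values to be tensors.
        """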
assert z.dim() == 1 and z.dtype == torch.long
batch = torch.zeros_like(z) if batch is None else batch
if self.derivative:
pos.requires_grad_(True)
out = self.network(z, pos, batch)
# compute gradients with respect to coordinates
if self.derivative:
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(out)]
dy = grad(
[out],
[pos],
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
)[0]
if dy is None:
raise RuntimeError("Autograd returned None for the force prediction.")
return out, -dy
# TODO: return only `out` once Union typing works with TorchScript:
# https://github.com/pytorch/pytorch/pull/53180
# Can't return `None` for the derivative because torch.jit.trace expects all
# return values to be Tensors
        return out, torch.empty(0)


class PredictionNetwork(nn.Module):
"""Temporary module to separate running the model and computing forces.
Calling torch.autograd.grad in a JIT compiled module currently doesn't work.
In order to compile the model using torch.jit.trace, the grad call has to be in a
separate nn.Module, which is not compiled. This module can be merged into TorchMD_Net
once https://github.com/pytorch/pytorch/issues/63145 is fixed.
"""
def __init__(
self,
representation_model,
output_model,
prior_model=None,
reduce_op="add",
mean=None,
std=None,
):
super(PredictionNetwork, self).__init__()
self.representation_model = representation_model
self.output_model = output_model
self.prior_model = prior_model
if not output_model.allow_prior_model and prior_model is not None:
self.prior_model = None
rank_zero_warn(
(
"Prior model was given but the output model does "
"not allow prior models. Dropping the prior model."
)
)
self.reduce_op = reduce_op
mean = torch.scalar_tensor(0) if mean is None else mean
self.register_buffer("mean", mean)
std = torch.scalar_tensor(1) if std is None else std
self.register_buffer("std", std)
        self.reset_parameters()

    def reset_parameters(self):
        self.representation_model.reset_parameters()
        self.output_model.reset_parameters()
        if self.prior_model is not None:
            self.prior_model.reset_parameters()

    def forward(self, z, pos, batch):
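        """Compute per-sample predictions: run the representation model, apply
        the output network, scale by ``std``, apply the prior model, aggregate
        over atoms with ``reduce_op``, shift by ``mean`` and apply the output
        model's post-reduction step.
        """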
# run the potentially wrapped representation model
x, v, z, pos, batch = self.representation_model(z, pos, batch=batch)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)
# scale by data standard deviation
if self.std is not None:
x = x * self.std
# apply prior model
if self.prior_model is not None:
x = self.prior_model(x, z, pos, batch)
# aggregate atoms
out = scatter(x, batch, dim=0, reduce=self.reduce_op)
# shift by data mean
if self.mean is not None:
out = out + self.mean
# apply output model after reduction
out = self.output_model.post_reduce(out)
return out
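

# Example (sketch): constructing a model from a hyperparameter dictionary and
# running a forward pass. The values below are illustrative placeholders, not
# defaults defined in this file.
#
#     args = {
#         "model": "equivariant-transformer",
#         "embedding_dimension": 128,
#         "num_layers": 6,
#         "num_rbf": 32,
#         "rbf_type": "expnorm",
#         "trainable_rbf": False,
#         "activation": "silu",
#         "neighbor_embedding": True,
#         "cutoff_lower": 0.0,
#         "cutoff_upper": 5.0,
#         "max_z": 100,
#         "max_num_neighbors": 32,
#         "attn_activation": "silu",
#         "num_heads": 8,
#         "distance_influence": "both",
#         "atom_filter": -1,
#         "prior_model": None,
#         "output_model": "Scalar",
#         "reduce_op": "add",
#         "derivative": True,
#     }
#     model = create_model(args)
#
#     z = torch.tensor([8, 1, 1], dtype=torch.long)  # atomic numbers
#     pos = torch.randn(3, 3)                        # coordinates
#     batch = torch.zeros(3, dtype=torch.long)       # all atoms belong to sample 0
#     energy, neg_dy = model(z, pos, batch)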