Stress Calculation in ForceRegressionTask (#302)
* feat: lightning upgrade

* fix: add back deleted devset files

* feat: adding stress calculation to ForceRegressionTask

* fix: handle frame averaging properly for stress, update tests.

* fix: adding back commented test

* fix: delete commented lines and revert pyproject.toml

* fix: remove added trainer arg

* fix: latest pyproject.toml

* fix: update docstring with reference
melo-gonzo authored Oct 10, 2024
1 parent b66d4c3 commit 88218c9
Showing 3 changed files with 240 additions and 58 deletions.
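
At a glance, the commit adds an opt-in stress/virial path to ForceRegressionTask via a new compute_stress flag. A minimal usage sketch, assuming a placeholder encoder (MyEncoder and batch below are illustrative stand-ins, not part of this commit):

# Hypothetical usage of the new flag; MyEncoder is a stand-in for any
# matsciml-compatible encoder and is not defined by this commit.
from matsciml.models.base import ForceRegressionTask

task = ForceRegressionTask(
    encoder=MyEncoder(),   # placeholder encoder instance
    compute_stress=True,   # new argument; defaults to False
)
# the batch must now carry "cell", "unit_offsets", and "src_nodes"
outputs = task(batch)
# outputs contains "energy" and "force", plus "stress" and "virials"
# when compute_stress=True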
2 changes: 0 additions & 2 deletions matsciml/datasets/transforms/representations/graphs.py
@@ -299,8 +299,6 @@ def epilogue(self, data: DataDict) -> None:
"pc_features",
"distance_matrix",
"atomic_numbers",
"src_nodes",
"dst_nodes",
"sizes",
]:
try:
231 changes: 176 additions & 55 deletions matsciml/models/base.py
@@ -677,10 +677,9 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = None,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = None,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
lr: float = 1e-4,
@@ -1325,10 +1324,9 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.MSELoss,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = nn.MSELoss,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
**kwargs: Any,
@@ -1446,10 +1444,9 @@ def __init__(
encoder: Optional[nn.Module] = None,
encoder_class: Optional[Type[nn.Module]] = None,
encoder_kwargs: Optional[Dict[str, Any]] = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.MSELoss,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = nn.MSELoss,
loss_coeff: Optional[Dict[str, Any]] = None,
task_keys: Optional[List[str]] = None,
output_kwargs: Dict[str, Any] = {},
@@ -1676,10 +1673,9 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.BCEWithLogitsLoss,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = nn.BCEWithLogitsLoss,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
**kwargs,
@@ -1755,13 +1751,13 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = None,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = None,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
embedding_reduction_type: str = "sum",
compute_stress: bool = False,
**kwargs,
) -> None:
if not loss_func:
@@ -1786,6 +1782,7 @@ def __init__(
self.save_hyperparameters(ignore=["encoder", "loss_func"])
# have to enable double backprop
self.automatic_optimization = False
self.compute_stress = compute_stress

def _make_output_heads(self) -> nn.ModuleDict:
# this task only utilizes one output head
@@ -1800,6 +1797,7 @@ def forward(
self,
batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
# virial/stress computation inspired from https://github.com/ACEsuit/mace/blob/main/mace/modules/utils.py
# for ease of use, this task will always compute forces
with dynamic_gradients_context(True, self.has_rnn):
# first ensure that positions tensor is backprop ready
@@ -1824,6 +1822,39 @@ def forward(
# no frame averaging architecture yet for point clouds
fa_rot = None
fa_pos = None

if self.compute_stress:
num_graphs = batch["graph"].batch_size
src_nodes = torch.cat(batch["src_nodes"])
cell = batch.get("cell")
unit_offsets = batch.get("unit_offsets")
batch_indices = torch.concat(
[
torch.tensor([idx] * batch["natoms"][idx].long())
for idx in range(num_graphs)
]
)
pos, shifts, displacement, symmetric_displacement = (
self.get_symmetric_displacement(
pos, unit_offsets, cell, src_nodes, num_graphs, batch_indices
)
)
batch["displacement"] = displacement

if "graph" in batch:
graph.pos = pos
if hasattr(graph, "ndata"):
graph.ndata["pos"] = pos

if fa_pos is not None:
for index in range(len(fa_pos)):
fa_pos[index].requires_grad_(True)
fa_pos[index] = fa_pos[index] + torch.einsum(
"be,bec->bc",
fa_pos[index],
symmetric_displacement[batch_indices],
)

if pos is None:
raise ValueError(
"No atomic positions were found in batch - neither as standalone tensor nor graph.",
@@ -1844,26 +1875,24 @@ def forward(
embeddings = batch.get("embeddings")
else:
embeddings = self.encoder(batch)
natoms = batch.get("natoms", None)
outputs = self.process_embedding(
embeddings, pos, fa_rot, fa_pos, natoms, graph
)

outputs = self.process_embedding(embeddings, batch, pos, fa_rot, fa_pos)
return outputs

def process_embedding(
self,
embeddings: Embeddings,
batch: BatchDict,
pos: torch.Tensor,
fa_rot: None | torch.Tensor = None,
fa_pos: None | torch.Tensor = None,
natoms: None | torch.Tensor = None,
graph: None | AbstractGraph = None,
) -> dict[str, torch.Tensor]:
graph = batch.get("graph")
natoms = batch.get("natoms")

outputs = {}
# compute node-level contributions to the energy
node_energies = self.output_heads["energy"](embeddings.point_embedding)
# figure out how we're going to reduce node level energies
# depending on the representation and/or the graph framework
if graph is not None:
if isinstance(graph, dgl.DGLGraph):
graph.ndata["node_energies"] = node_energies
@@ -1872,17 +1901,18 @@ def readout(node_energies: torch.Tensor):
return dgl.readout_nodes(
graph, "node_energies", op=self.embedding_reduction_type
)

else:
# assumes a batched pyg graph
batch = getattr(graph, "batch", None)
if batch is None:
batch = torch.zeros_like(graph.atomic_numbers)
batch_indices = getattr(graph, "batch", None)
if batch_indices is None:
batch_indices = torch.zeros_like(graph.atomic_numbers)
from torch_geometric.utils import scatter

def readout(node_energies: torch.Tensor):
return scatter(
node_energies,
batch,
batch_indices,
dim=-2,
reduce=self.embedding_reduction_type,
)
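
A self-contained toy run of the PyG-style readout defined above, using torch_geometric's scatter to sum node energies into per-graph energies:

# Toy illustration only; shapes mirror the readout closure above.
import torch
from torch_geometric.utils import scatter

node_energies = torch.randn(5, 1)              # 5 nodes, 1 energy each
batch_indices = torch.tensor([0, 0, 0, 1, 1])  # first three nodes in graph 0
graph_energies = scatter(node_energies, batch_indices, dim=-2, reduce="sum")
print(graph_energies.shape)  # torch.Size([2, 1])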
@@ -1901,61 +1931,115 @@ def energy_and_force(
energy = readout(node_energies)
if energy.ndim == 1:
energy = energy.unsqueeze(-1)
# now use autograd for force calculation
force = (
-1
* torch.autograd.grad(
energy,
pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
# now use autograd for the force and virial calculation
inputs = [pos]

if self.compute_stress:
displacement = batch["displacement"]
inputs.append(displacement)

# outputs will be (force, virials)
outputs = torch.autograd.grad(
outputs=energy,
inputs=inputs,
grad_outputs=[torch.ones_like(energy)],
create_graph=True,
)
return energy, force

if self.compute_stress:
virials = outputs[1]
stress = torch.zeros_like(displacement)
cell = batch["cell"].view(-1, 3, 3)
volume = torch.linalg.det(cell).abs().unsqueeze(-1)
stress = virials / volume.view(-1, 1, 1)
stress = torch.where(
torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)
)
stress = -1 * stress
else:
stress = None
virials = None
force = -1 * outputs[0]
return energy, force, stress, virials

# not using frame averaging
if fa_pos is None:
energy, force = energy_and_force(pos, node_energies, readout)
energy, force, stress, virials = energy_and_force(
pos, node_energies, readout
)
else:
energy = []
force = []
stress = []
virials = []
for idx, pos in enumerate(fa_pos):
frame_embedding = node_energies[:, idx, :]
frame_energy, frame_force = energy_and_force(
pos, frame_embedding, readout
frame_energy, frame_force, frame_stress, frame_virials = (
energy_and_force(pos, frame_embedding, readout)
)
force.append(frame_force)
energy.append(frame_energy.unsqueeze(-1))
stress.append(frame_stress)
virials.append(frame_virials)

# check to see if we are frame averaging
if fa_rot is not None:
all_forces = []
all_stress = []
all_virials = []
# loop over each frame prediction, and transform to guarantee
# equivariance of frame averaging method
natoms = natoms.squeeze(-1).to(int)
for frame_idx, frame_rot in enumerate(fa_rot):
repeat_rot = torch.repeat_interleave(
force_repeat_rot = torch.repeat_interleave(
frame_rot,
natoms,
dim=0,
).to(self.device)
rotated_forces = (
force[frame_idx].view(-1, 1, 3).bmm(repeat_rot.transpose(1, 2))
force[frame_idx]
.view(-1, 1, 3)
.bmm(force_repeat_rot.transpose(1, 2))
)
all_forces.append(rotated_forces)
if self.compute_stress:
for frame_idx, frame_rot in enumerate(fa_rot):
rotated_stress = stress[frame_idx].bmm(frame_rot.transpose(1, 2))
rotated_virials = virials[frame_idx].bmm(frame_rot.transpose(1, 2))
# adding dim to concat on
all_stress.append(rotated_stress.unsqueeze(1))
all_virials.append(rotated_virials.unsqueeze(1))

# combine all the force and energy data into a single tensor
# using frame averaging, the expected shapes after concatenation are:
# force - [num positions, num frames, 3]
# energy - [batch size, num frames, 1]
# stress/virials - [batch size, 3, 3]
force = torch.cat(all_forces, dim=1)
energy = torch.cat(energy, dim=1)
if self.compute_stress:
stress = torch.cat(all_stress, dim=1)
virials = torch.cat(all_virials, dim=1)
# reduce outputs to what are expected shapes
outputs["force"] = reduce(
force,
"n ... d -> n d",
self.embedding_reduction_type,
d=3,
)
if self.compute_stress:
outputs["stress"] = reduce(
stress,
"b ... n d -> b n d",
self.embedding_reduction_type,
d=3,
)
outputs["virials"] = reduce(
virials,
"b ... n d -> b n d",
self.embedding_reduction_type,
d=3,
)
# this may not do anything if we aren't frame averaging
# since the reduction is also done in the energy_and_force call
outputs["energy"] = reduce(
@@ -2178,6 +2262,45 @@ def _make_normalizers(self) -> dict[str, Normalizer]:
)
return normalizers

@staticmethod
def get_symmetric_displacement(
positions: torch.Tensor,
unit_offsets: torch.Tensor,
cell: Optional[torch.Tensor],
src_nodes: torch.Tensor,
num_graphs: int,
batch: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""virial/stress computation inspired from:
https://github.com/ACEsuit/mace/blob/main/mace/modules/utils.py
https://github.com/mir-group/nequip
"""
if cell is None:
cell = torch.zeros(
num_graphs * 3,
3,
dtype=positions.dtype,
device=positions.device,
)
displacement = torch.zeros(
(num_graphs, 3, 3),
dtype=positions.dtype,
device=positions.device,
)
displacement.requires_grad_(True)
symmetric_displacement = 0.5 * (displacement + displacement.transpose(-1, -2))
positions = positions + torch.einsum(
"be,bec->bc", positions, symmetric_displacement[batch]
)
cell = cell.view(-1, 3, 3)
cell = cell + torch.matmul(cell, symmetric_displacement)
shifts = torch.einsum(
"be,bec->bc",
unit_offsets,
cell[batch[src_nodes]],
)
return positions, shifts, displacement, symmetric_displacement

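A quick sanity check of the helper's contract (toy tensors; the import path assumes the file in this diff): with a fresh zero displacement, the returned positions are numerically unchanged but now depend on the displacement in the autograd graph.

# Toy check; not taken from the repository's test suite.
import torch
from matsciml.models.base import ForceRegressionTask

num_graphs = 2
pos = torch.randn(6, 3)                       # three atoms per graph
batch_idx = torch.tensor([0, 0, 0, 1, 1, 1])  # graph index per atom
src = torch.arange(6)                         # toy edge source indices
offsets = torch.zeros(6, 3)                   # toy unit offsets, one per edge
cell = torch.eye(3).repeat(num_graphs, 1, 1)

new_pos, shifts, disp, sym = ForceRegressionTask.get_symmetric_displacement(
    pos, offsets, cell, src, num_graphs, batch_idx
)
assert torch.allclose(new_pos, pos)           # zero strain: positions unchanged
assert disp.requires_grad and disp.shape == (num_graphs, 3, 3)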

@registry.register_task("GradFreeForceRegressionTask")
class GradFreeForceRegressionTask(ScalarRegressionTask):
@@ -2186,10 +2309,9 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.MSELoss,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = nn.MSELoss,
output_kwargs: dict[str, Any] = {},
**kwargs: Any,
) -> None:
@@ -2326,10 +2448,9 @@ def __init__(
encoder: nn.Module | None = None,
encoder_class: type[nn.Module] | None = None,
encoder_kwargs: dict[str, Any] | None = None,
loss_func: type[nn.Module]
| nn.Module
| dict[str, nn.Module | type[nn.Module]]
| None = nn.CrossEntropyLoss,
loss_func: (
type[nn.Module] | nn.Module | dict[str, nn.Module | type[nn.Module]] | None
) = nn.CrossEntropyLoss,
output_kwargs: dict[str, Any] = {},
normalize_kwargs: dict[str, float] | None = None,
freeze_embedding: bool = False,
