Skip to content

Commit fa102e6

Browse files
author
Raimondas Galvelis
authored
Update the test models and CI dependencies (#73)
* Update dependencies * Downgrade GCC * Reverse enginier a model generator * Regenerate the model with PyTorch 1.11 * Update to PyTorch 1.11 * Update to Python 3.10 * Empty line * Simplify the test models
1 parent 84f7d88 commit fa102e6

File tree

6 files changed

+31
-7
lines changed

6 files changed

+31
-7
lines changed

.github/workflows/CI.yml

+7-7
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@ jobs:
2323
matrix:
2424
include:
2525
# Oldest supported versions
26-
- name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.7)
26+
- name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.11)
2727
os: ubuntu-18.04
2828
cuda-version: "10.2.89"
29-
gcc-version: "9.4.*"
29+
gcc-version: "8.5.*"
3030
nvcc-version: "10.2"
3131
python-version: "3.7"
32-
pytorch-version: "1.7.*"
32+
pytorch-version: "1.11.*"
3333

3434
# Latest supported versions
35-
- name: Linux (CUDA 11.2, Python 3.9, PyTorch 1.10)
35+
- name: Linux (CUDA 11.2, Python 3.10, PyTorch 1.11)
3636
os: ubuntu-18.04
3737
cuda-version: "11.2.2"
38-
gcc-version: "11.2.*"
38+
gcc-version: "10.3.*"
3939
nvcc-version: "11.2"
40-
python-version: "3.9"
41-
pytorch-version: "1.10.*"
40+
python-version: "3.10"
41+
pytorch-version: "1.11.*"
4242

4343
- name: MacOS (Python 3.9, PyTorch 1.9)
4444
os: macos-11

tests/central.pt

-267 Bytes
Binary file not shown.

tests/forces.pt

0 Bytes
Binary file not shown.

tests/generate.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch as pt
2+
3+
class Central(pt.nn.Module):
4+
def forward(self, pos):
5+
return pos.pow(2).sum()
6+
7+
class Forces(pt.nn.Module):
8+
def forward(self, pos):
9+
return pos.pow(2).sum(), -2 * pos
10+
11+
class Global(pt.nn.Module):
12+
def forward(self, pos, k):
13+
return k * pos.pow(2).sum()
14+
15+
class Periodic(pt.nn.Module):
16+
def forward(self, pos, box):
17+
box = box.diagonal().unsqueeze(0)
18+
pos = pos - (pos / box).floor() * box
19+
return pos.pow(2).sum()
20+
21+
pt.jit.script(Central()).save('central.pt')
22+
pt.jit.script(Forces()).save('forces.pt')
23+
pt.jit.script(Global()).save('global.pt')
24+
pt.jit.script(Periodic()).save('periodic.pt')

tests/global.pt

192 Bytes
Binary file not shown.

tests/periodic.pt

-784 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)