File tree 6 files changed +31
-7
lines changed
6 files changed +31
-7
lines changed Original file line number Diff line number Diff line change @@ -23,22 +23,22 @@ jobs:
23
23
matrix :
24
24
include :
25
25
# Oldest supported versions
26
- - name : Linux (CUDA 10.2, Python 3.7, PyTorch 1.7 )
26
+ - name : Linux (CUDA 10.2, Python 3.7, PyTorch 1.11 )
27
27
os : ubuntu-18.04
28
28
cuda-version : " 10.2.89"
29
- gcc-version : " 9.4 .*"
29
+ gcc-version : " 8.5 .*"
30
30
nvcc-version : " 10.2"
31
31
python-version : " 3.7"
32
- pytorch-version : " 1.7 .*"
32
+ pytorch-version : " 1.11 .*"
33
33
34
34
# Latest supported versions
35
- - name : Linux (CUDA 11.2, Python 3.9 , PyTorch 1.10 )
35
+ - name : Linux (CUDA 11.2, Python 3.10 , PyTorch 1.11 )
36
36
os : ubuntu-18.04
37
37
cuda-version : " 11.2.2"
38
- gcc-version : " 11.2 .*"
38
+ gcc-version : " 10.3 .*"
39
39
nvcc-version : " 11.2"
40
- python-version : " 3.9 "
41
- pytorch-version : " 1.10 .*"
40
+ python-version : " 3.10 "
41
+ pytorch-version : " 1.11 .*"
42
42
43
43
- name : MacOS (Python 3.9, PyTorch 1.9)
44
44
os : macos-11
Original file line number Diff line number Diff line change
1
+ import torch as pt
2
+
3
+ class Central (pt .nn .Module ):
4
+ def forward (self , pos ):
5
+ return pos .pow (2 ).sum ()
6
+
7
+ class Forces (pt .nn .Module ):
8
+ def forward (self , pos ):
9
+ return pos .pow (2 ).sum (), - 2 * pos
10
+
11
+ class Global (pt .nn .Module ):
12
+ def forward (self , pos , k ):
13
+ return k * pos .pow (2 ).sum ()
14
+
15
+ class Periodic (pt .nn .Module ):
16
+ def forward (self , pos , box ):
17
+ box = box .diagonal ().unsqueeze (0 )
18
+ pos = pos - (pos / box ).floor () * box
19
+ return pos .pow (2 ).sum ()
20
+
21
+ pt .jit .script (Central ()).save ('central.pt' )
22
+ pt .jit .script (Forces ()).save ('forces.pt' )
23
+ pt .jit .script (Global ()).save ('global.pt' )
24
+ pt .jit .script (Periodic ()).save ('periodic.pt' )
You can’t perform that action at this time.
0 commit comments