Commit dbab6a6

Restructuring (#19)
* Move python torchscripting script to utils folder.
* Add README for utils (WIP).
* Create examples directory with basic README listing examples to add.
* Move ts_inference programs to examples directory with a copy of CMakeLists and modify library CMake to no longer build and install them.
* Move case-specific pt2ts to examples, and replace with generic tool.
* Restructure and re-write ResNet example. WIP (needs multi-input merge from main).
* Update resnet example to work.
* Complete Python-Fortran example.
* Move C and C++ code from example 1 to another file.
* Add Makefile build option to ResNet example.
* Tidy ResNet README.
* Update resnet example to take model as command line argument.
* Update README files as appropriate for these changes.
1 parent 448dc72 commit dbab6a6

23 files changed: +866 −194 lines

Diff for: README.md (+4 −3)

````diff
@@ -85,7 +85,6 @@ To build and install the library:
 make install
 ```
 This will place the following directories at the install location:
-* `bin/` - contains example executables
 * `include/` - contains header and mod files
 * `lib64/` - contains cmake and `.so` files
 
@@ -104,7 +103,7 @@ In order to use fortran-pytorch users will typically need to follow these steps:
 The trained PyTorch model needs to be exported to [TorchScript](https://pytorch.org/docs/stable/jit.html).
 This can be done from within your code using the [`jit.script`](https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script) or [`jit.trace`](https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace) functionalities from within python.
 
-If you are not familiar with these we provide a tool `pt2ts.py` as part of this distribution which contains an easily adaptable script to save your PyTorch model as Torch Script.
+If you are not familiar with these we provide a tool [`pt2ts.py`](utils/pt2ts.py) as part of this distribution which contains an easily adaptable script to save your PyTorch model as TorchScript.
 
 
 ### 2. Using the model from Fortran
@@ -209,7 +208,9 @@ export LD_LIBRARY_PATH = $LD_LIBRARY_PATH:<path/to/installation>/lib64
 
 
 ## Examples
-To follow.
+
+Examples of how to use this library are provided in the [examples directory](examples/).
+They demonstrate different functionalities and are provided with instructions to modify, build, and run as necessary.
 
 ## License
 
````
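The export step described in the second hunk is short in practice. As a minimal sketch, assuming a hypothetical toy module `MyNet` in place of a real trained model, both routes look like this:

```python
# Minimal sketch of exporting a PyTorch model to TorchScript.
# `MyNet` is a hypothetical stand-in for your own trained model.
import torch


class MyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


model = MyNet()
model.eval()  # switch to inference mode before export

# Route 1: scripting compiles the module and preserves control flow
torch.jit.script(model).save("scripted_model.pt")

# Route 2: tracing records one forward pass through a dummy input
dummy_input = torch.ones(1, 4)
torch.jit.trace(model, dummy_input).save("traced_model.pt")
```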
Diff for: examples/1_ResNet18/CMakeLists.txt (+19, new file)

```cmake
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
# policy CMP0076 - target_sources source files are relative to file where target_sources is run
cmake_policy(SET CMP0076 NEW)

set(PROJECT_NAME ResNetExample)

project(${PROJECT_NAME} LANGUAGES Fortran)

# Build in Debug mode if not specified
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE)
endif()

find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")

# Fortran example
add_executable(resnet_infer_fortran resnet_infer_fortran.f90)
target_link_libraries(resnet_infer_fortran PRIVATE FTorch::ftorch)
```
Diff for: examples/1_ResNet18/Makefile (+25, new file)

```make
# compiler
# Note - this should match the compiler that the library was built with
FC = gfortran

# compile flags
FCFLAGS = -O3 -I</path/to/installation>/include/ftorch

# link flags
LDFLAGS = -L</path/to/installation>/lib64/ -lftorch

PROGRAM = resnet_infer_fortran
SRC = resnet_infer_fortran.f90
OBJECTS = $(SRC:.f90=.o)

all: $(PROGRAM)

$(PROGRAM): $(OBJECTS)
	$(FC) $(FCFLAGS) -o $@ $^ $(LDFLAGS)

%.o: %.f90
	$(FC) $(FCFLAGS) $(LDFLAGS) -c $<

clean:
	rm -f *.o *.mod
```
Diff for: examples/1_ResNet18/README.md (+83, new file)

````markdown
# Example 1 - ResNet-18

This example provides a simple but complete demonstration of how to use the library.

## Description

A Python file is provided that downloads the pretrained
[ResNet-18](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html)
model from [TorchVision](https://pytorch.org/vision/stable/index.html).

A modified version of the `pt2ts.py` tool saves this ResNet-18 to TorchScript.

A series of files `resnet_infer_<LANG>` then bind from other languages to run the
TorchScript ResNet-18 model in inference mode.

## Dependencies

To run this example requires:

- cmake
- Fortran compiler
- FTorch (installed as described in the main package)
- python3

## Running

To run this example install fortran-pytorch-lib as described in the main documentation.
Then from this directory create a virtual environment and install the necessary Python
modules:
```
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

You can check that everything is working by running `resnet18.py`:
```
python3 resnet18.py
```
It should produce the result `tensor([[623, 499, 596, 111, 813]])`.

To save the pretrained ResNet-18 model to TorchScript run the modified version of the
`pt2ts.py` tool:
```
python3 pt2ts.py
```

At this point we no longer require Python, so we can deactivate the virtual environment:
```
deactivate
```

To call the saved ResNet-18 model from Fortran we need to compile the `resnet_infer`
files.
This can be done using the included `CMakeLists.txt` as follows:
```
mkdir build
cd build
cmake .. -DFTorchDIR=<path/to/your/installation/of/library> -DCMAKE_BUILD_TYPE=Release
make
```

To run the compiled code calling the saved ResNet-18 TorchScript from Fortran, run the
executable with an argument of the saved model file:
```
./resnet_infer_fortran ../saved_resnet18_model_cpu.pt
```

Alternatively we can use `make`, instead of CMake, with the included Makefile.
However, to do this you will need to modify `Makefile` to link to and include your
installation of FTorch as described in the main documentation.
Also check that the compiler is the same as the one you built the library with.
You will also likely need to add the location of the `.so` files to your `LD_LIBRARY_PATH`:
```
make
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:</path/to/library/installation>/lib64
./resnet_infer_fortran saved_resnet18_model_cpu.pt
```

## Further options

To explore the functionalities of this model:

- Try saving the model through tracing rather than scripting by modifying `pt2ts.py`
````
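Before moving on to the Fortran build, the saved TorchScript file can also be exercised directly from Python as a cross-check of what `resnet_infer_fortran` should produce. A minimal sketch, assuming `pt2ts.py` has already written `saved_resnet18_model_cpu.pt` to the current directory:

```python
# Load the TorchScript ResNet-18 saved by pt2ts.py and repeat the inference
# on an input of ones, for comparison with the Fortran executable's output.
import torch
import torch.nn.functional as F

ts_model = torch.jit.load("saved_resnet18_model_cpu.pt")
ts_model.eval()

with torch.no_grad():
    output = ts_model(torch.ones(1, 3, 224, 224))

# Expect the same top-5 class indices as `python3 resnet18.py`:
# tensor([[623, 499, 596, 111, 813]])
print(F.softmax(output, dim=1).topk(5).indices)
```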
Diff for: examples/1_ResNet18/pt2ts.py (+160, new file)

```python
"""Load a PyTorch model and convert it to TorchScript."""
from typing import Optional
import torch

# FPTLIB-TODO
# Add a module import with your model here:
# This example assumes the model architecture is in an adjacent module `resnet18.py`
import resnet18


def script_to_torchscript(
    model: torch.nn.Module, filename: Optional[str] = "scripted_model.pt"
) -> None:
    """
    Save PyTorch model to TorchScript using scripting.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model
    filename : str
        name of file to save to
    """
    print("Saving model using scripting...", end="")
    # FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
    scripted_model = torch.jit.script(model)
    # print(scripted_model.code)
    scripted_model.save(filename)
    print("done.")


def trace_to_torchscript(
    model: torch.nn.Module,
    dummy_input: torch.Tensor,
    filename: Optional[str] = "traced_model.pt",
) -> None:
    """
    Save PyTorch model to TorchScript using tracing.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model
    dummy_input : torch.Tensor
        appropriately sized Tensor to act as input to the model
    filename : str
        name of file to save to
    """
    print("Saving model using tracing...", end="")
    # FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
    traced_model = torch.jit.trace(model, dummy_input)
    # traced_model.save(filename)
    frozen_model = torch.jit.freeze(traced_model)
    # print(frozen_model.graph)
    # print(frozen_model.code)
    frozen_model.save(filename)
    print("done.")


def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Module:
    """
    Load a TorchScript model from file.

    Parameters
    ----------
    filename : str
        name of file containing the TorchScript model
    """
    model = torch.jit.load(filename)

    return model


if __name__ == "__main__":
    # =====================================================
    # Load model and prepare for saving
    # =====================================================

    # FPTLIB-TODO
    # Load a pre-trained PyTorch model
    # Insert code here to load your model as `trained_model`.
    # This example assumes `resnet18` has a method `initialize` to load the
    # architecture and weights, and place the model in inference mode
    trained_model = resnet18.initialize()

    # Switch off specific layers/parts of the model that behave
    # differently during training and inference.
    # This may have been done by the user already, so just make sure here.
    trained_model.eval()

    # =====================================================
    # Prepare dummy input and check model runs
    # =====================================================

    # FPTLIB-TODO
    # Generate a dummy input Tensor `dummy_input` to the model of appropriate size.
    # This example assumes a single input of size (1, 3, 224, 224)
    trained_model_dummy_input_1 = torch.ones(1, 3, 224, 224)

    # FPTLIB-TODO
    # Uncomment the following lines to save for inference on GPU (rather than CPU):
    # device = torch.device('cuda')
    # trained_model = trained_model.to(device)
    # trained_model.eval()
    # trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)

    # FPTLIB-TODO
    # Run model for dummy inputs
    # If something isn't working, this will generate an error
    trained_model_dummy_output = trained_model(
        trained_model_dummy_input_1,
    )

    # =====================================================
    # Save model
    # =====================================================

    # FPTLIB-TODO
    # Set the name of the file you want to save the TorchScript model to:
    saved_ts_filename = "saved_resnet18_model_cpu.pt"

    # FPTLIB-TODO
    # Save the PyTorch model using either scripting (recommended where possible)
    # or tracing
    # -----------
    # Scripting
    # -----------
    script_to_torchscript(trained_model, filename=saved_ts_filename)

    # -----------
    # Tracing
    # -----------
    # trace_to_torchscript(
    #     trained_model, trained_model_dummy_input_1, filename=saved_ts_filename
    # )

    print(f"Saved model to TorchScript in '{saved_ts_filename}'.")

    # =====================================================
    # Check model saved OK
    # =====================================================

    # Load the TorchScript model and run it as a test
    # FPTLIB-TODO
    # Scale inputs as above and, if required, move inputs and model to GPU
    trained_model_dummy_input_1 = 2.0 * trained_model_dummy_input_1
    trained_model_testing_output = trained_model(
        trained_model_dummy_input_1,
    )
    ts_model = load_torchscript(filename=saved_ts_filename)
    ts_model_output = ts_model(
        trained_model_dummy_input_1,
    )

    if torch.all(ts_model_output.eq(trained_model_testing_output)):
        print("Saved TorchScript model working as expected in a basic test.")
        print("Users should perform further validation as appropriate.")
    else:
        raise RuntimeError(
            "Saved TorchScript model is not performing as expected.\n"
            "Consider using scripting if you used tracing, or investigate further."
        )
```
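The comment in `pt2ts.py` that scripting is "recommended where possible" deserves a word of explanation: tracing records only the single execution path taken by the dummy input, so any data-dependent control flow is frozen into the saved model, whereas scripting compiles the module and preserves it. A minimal sketch with a hypothetical toy module (not part of this commit) illustrating the difference:

```python
# Why scripting is preferred over tracing when a model contains
# data-dependent control flow. `Flip` is a hypothetical toy module.
import torch


class Flip(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branch depends on the input values, not just their shape
        if x.sum() > 0:
            return x
        return -x


model = Flip()
traced = torch.jit.trace(model, torch.ones(3))  # emits a TracerWarning here
scripted = torch.jit.script(model)

x = -torch.ones(3)
print(traced(x))    # tensor([-1., -1., -1.]) - the traced branch was frozen in
print(scripted(x))  # tensor([1., 1., 1.])    - branch evaluated at run time
```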
Diff for: examples/1_ResNet18/requirements.txt (+2, new file)

```
torch
torchvision
```
Diff for: examples/1_ResNet18/resnet18.py (+50, new file)

```python
"""Load and run pretrained ResNet-18 from TorchVision."""

import torch
import torch.nn.functional as F
import torchvision


# Initialize everything
def initialize():
    """
    Download pre-trained ResNet-18 model and prepare for inference.

    Returns
    -------
    model : torch.nn.Module
    """
    # Load a pre-trained PyTorch model
    print("Loading pre-trained ResNet-18 model...", end="")
    model = torchvision.models.resnet18(pretrained=True)
    print("done.")

    # Switch off some specific layers/parts of the model that behave
    # differently during training and inference
    model.eval()

    return model


def run_model(model):
    """
    Run the pre-trained ResNet-18 with a dummy input of ones.

    Parameters
    ----------
    model : torch.nn.Module
    """
    print("Running ResNet-18 model for ones...", end="")
    dummy_input = torch.ones(1, 3, 224, 224)
    output = model(dummy_input)
    top5 = F.softmax(output, dim=1).topk(5).indices
    print("done.")

    print(f"Top 5 results:\n {top5}")


if __name__ == "__main__":
    rn_model = initialize()
    run_model(rn_model)
```