👓 Exporting habitat-lab policies/networks to TorchScript for loading into the spot-sim2real environment without library and environment conflicts
You might want to train new policies or networks in habitat-lab. However, the packages in the habitat-lab conda environment and those in the spot-ros environment (used for spot-sim2real) can have incompatible versions. We therefore export the model to an intermediate representation (IR) using the TorchScript module provided by PyTorch. Decoupling a model's deployment environment from its development environment gives the model developer more freedom. We wrote a conversion script that converts the mobile-gaze policy, which was trained in a newer version of habitat-lab, to a TorchScript model.
To use pytorch_to_torchscript.py, add the relevant conversion parameters to conversionparams.yaml, then run python pytorch_to_torchscript.py -c conversionparams.yaml
In general, these are the steps you can follow for conversion.
- Load the PyTorch model together with its class files and move the model to CUDA
- Pass a random input tensor to the model and trace its forward pass using
torch.jit.trace
(a usage example can be found here)
- Save the traced model as modelXX.torchscript, replacing modelXX with the desired name
- To load the model in spot-sim2real, use
torch.jit.load("path/to/saved/torchscript/model", map_location="cuda:0")
- To download the existing TorchScript checkpoints, run these commands from the spot-sim2real folder:
cd spot_rl_experiments/weights
git clone https://huggingface.co/spaces/jimmytyyang/spot-sim2real-data
unzip spot-sim2real-data/weight/torchscript.zip && rm -rf spot-sim2real-torchscript-data && cd ../..
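The conversion steps listed above (load, trace, save, then load again on the deployment side) can be sketched roughly as follows. This is a minimal illustration and not the actual pytorch_to_torchscript.py script: TinyPolicy, its layer sizes, and the temporary save path are hypothetical stand-ins for a real habitat-lab policy, and the sketch falls back to CPU when no GPU is available.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in for a trained habitat-lab policy network."""

    def __init__(self, obs_dim: int = 8, action_dim: int = 2) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 16), nn.ReLU(), nn.Linear(16, action_dim)
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


# Assumption: fall back to CPU so the sketch also runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. Build (or load) the model and move it to the target device.
model = TinyPolicy().to(device).eval()

# 2. Trace the forward pass with a random input of the expected shape.
example_obs = torch.rand(1, 8, device=device)
with torch.no_grad():
    traced = torch.jit.trace(model, example_obs)

# 3. Save the traced model (replace modelXX with the desired name).
save_path = os.path.join(tempfile.mkdtemp(), "modelXX.torchscript")
traced.save(save_path)

# 4. Load it back; the TinyPolicy class definition is not needed here.
loaded = torch.jit.load(save_path, map_location=device)
with torch.no_grad():
    expected = model(example_obs)
    actual = loaded(example_obs)

print("outputs match:", torch.allclose(expected, actual, atol=1e-6))
```

Note that torch.jit.trace only records the operations executed for the example input, so a model whose forward pass has data-dependent control flow may need torch.jit.script instead.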
We encountered a CUDA error when setting up the recent habitat-lab version: it installs a recent version of PyTorch (2.2.1) built for CUDA 11.8. However, the driver on our hardware only supported a CUDA version older than 11.8, so torch.cuda.is_available() returned False and PyTorch reported that the driver was too old.
To fix this, first uninstall PyTorch in your habitat-lab conda env using
pip uninstall torch torchvision torchaudio
then, in the same habitat-lab env, run
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
(these are the PyTorch and CUDA versions we use for the spot-sim2real/spot-ros env).
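To check whether the environment is affected by the driver mismatch, or to confirm that the reinstall worked, a short diagnostic along these lines may help. The helper name cuda_report is hypothetical; it only uses standard PyTorch attributes and reports gracefully if PyTorch is not installed.

```python
import importlib.util


def cuda_report() -> dict:
    """Collect the PyTorch/CUDA facts relevant to a driver mismatch."""
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}

    import torch

    return {
        "torch_installed": True,
        # Installed PyTorch version, e.g. 1.12.1 after the downgrade.
        "torch_version": torch.__version__,
        # CUDA version the binary was built against (None for CPU-only builds).
        "built_for_cuda": torch.version.cuda,
        # False when the driver is too old for the CUDA build, or no GPU exists.
        "cuda_available": torch.cuda.is_available(),
    }


report = cuda_report()
for key, value in report.items():
    print(f"{key}: {value}")
```

If cuda_available is still False after the reinstall, comparing built_for_cuda with the CUDA version shown by nvidia-smi is a reasonable next step.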