Official implementation of the paper "Neural Fields with Thermal Activations for Arbitrary-Scale Super-Resolution" by Alexander Becker*, Rodrigo Daudt*, Nando Metzger, Jan Dirk Wegner, Konrad Schindler (* equal contribution)
You need a Python 3.10 environment (e.g., installed via conda) on Linux as well as an NVIDIA GPU (or cloud TPU). Then install packages via pip:
> pip install --upgrade pip
> pip install -r requirements_cu11.txt # CUDA 11
# or
> pip install -r requirements_cu12.txt # CUDA 12
# or
> pip install -r requirements_tpu.txt # TPU
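After installing, you can verify that JAX actually sees your accelerator; it should list GPU or TPU devices rather than falling back to CPU:
> python -c "import jax; print(jax.devices())"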
Download checkpoints [here]. Super-resolve any image with, e.g.:
> ./super_resolve.py IN_FILE OUT_FILE --scale 3.14 --checkpoint checkpoints/thera-L-swin-ir.pkl --backbone swin-ir --model-size L
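The checkpoints are single .pkl files; assuming they are standard pickles (their internal layout is not documented here), you can peek at one from Python:
> python -c "import pickle; ckpt = pickle.load(open('checkpoints/thera-L-swin-ir.pkl', 'rb')); print(type(ckpt))"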
You can evaluate the models on datasets using the run_eval.py script, e.g.:
> python run_eval.py --checkpoint checkpoints/thera-M-edsr-baseline.pkl --data-dir path_to_data_parent_folder --eval-sets data_folder_1 data_folder_2 ...
Check the arguments in args.py (bottom of file) for all testing options.
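To sweep several checkpoints over the same evaluation set, a simple shell loop works (paths reuse the placeholders from above):
> for ckpt in checkpoints/*.pkl; do python run_eval.py --checkpoint "$ckpt" --data-dir path_to_data_parent_folder --eval-sets data_folder_1; done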
Train and evaluate using:
> python run_train_and_eval.py --data-dir path_to_data_parent_folder --train-set train_data_folder --val-set val_data_folder
Check the arguments in args.py for all training options. Our implementation automatically shards over all available devices; this can be overridden by manually setting --n-devices or CUDA_VISIBLE_DEVICES, as shown below.
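For example, to shard training over only the first two GPUs (the device indices are just an example):
> CUDA_VISIBLE_DEVICES=0,1 python run_train_and_eval.py --data-dir path_to_data_parent_folder --train-set train_data_folder --val-set val_data_folder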
Useful environment variables:
- Disable pre-allocation of entire VRAM: XLA_PYTHON_CLIENT_PREALLOCATE=false
- Force GPU determinism (slow): XLA_FLAGS=--xla_gpu_deterministic_ops=true
- Disable jitting for debugging: JAX_DISABLE_JIT=1
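These can be combined in a single invocation; e.g., a JIT-free, non-preallocating debug run of the evaluation script:
> JAX_DISABLE_JIT=1 XLA_PYTHON_CLIENT_PREALLOCATE=false python run_eval.py --checkpoint checkpoints/thera-M-edsr-baseline.pkl --data-dir path_to_data_parent_folder --eval-sets data_folder_1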
Please cite our paper if you find our work helpful:
@article{becker2023neural,
  title={Neural Fields with Thermal Activations for Arbitrary-Scale Super-Resolution},
  author={Becker, Alexander and Daudt, Rodrigo Caye and Metzger, Nando and Wegner, Jan Dirk and Schindler, Konrad},
  journal={arXiv preprint arXiv:2311.17643},
  year={2023}
}