This repository contains code for "Fighting Uncertainty with Gradients: Offline Reinforcement Learning via Diffusion Score Matching".
This codebase solves uncertainty-penalized optimal control problems of the following form:
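$$
\begin{aligned}
\max_\theta \quad & \mathbb{E}_{x_i\sim\rho}\left[\sum_t r(x_t,u_t) + \beta \sum_t \log p(x_t,u_t)\right] \\
\text{s.t.} \quad & x_{t+1} = f(x_t,u_t) \quad \forall t, \\
& u_t = \pi_\theta(x_t) \quad \forall t, \\
& x_0 = x_i
\end{aligned}
$$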
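where θ are the policy parameters, r is the reward, f is the dynamics model, ρ is the distribution of initial states x_i, β weights the data-likelihood term, and p is the perturbed empirical distribution of the data. The log p term encourages rollout trajectories to stay close to the data.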
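The objective can be estimated by rolling the policy out from sampled initial states. Below is a minimal pure-Python sketch, not the repo's implementation. It assumes a scalar state and input, and it assumes p is the data perturbed by isotropic Gaussian noise of scale sigma, which makes p an equally weighted Gaussian mixture with one component per data point. The names log_p and penalized_objective, the single-integrator dynamics, and the data are all hypothetical.

```python
import math
import random
from typing import Callable, Sequence, Tuple

Point = Tuple[float, float]  # z = (x, u) for a scalar state x and input u


def log_p(z: Point, data: Sequence[Point], sigma: float) -> float:
    """log density of the empirical data distribution perturbed by N(0, sigma^2 I).

    Assumed perturbation: the result is an equally weighted Gaussian mixture with
    one component per data point, so log p is a log-sum-exp over the data.
    """
    logits = [-((z[0] - a) ** 2 + (z[1] - b) ** 2) / (2 * sigma**2) for a, b in data]
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))  # stable log-sum-exp
    return lse - math.log(len(data)) - math.log(2 * math.pi * sigma**2)


def penalized_objective(
    policy: Callable[[float], float],
    f: Callable[[float, float], float],
    r: Callable[[float, float], float],
    x0_samples: Sequence[float],
    horizon: int,
    data: Sequence[Point],
    beta: float,
    sigma: float,
) -> float:
    """Monte Carlo estimate of E_{x_i ~ rho}[sum_t r(x_t, u_t) + beta * sum_t log p(x_t, u_t)]."""
    total = 0.0
    for x in x0_samples:  # x_0 = x_i, with x_i drawn from rho
        for _ in range(horizon):
            u = policy(x)  # u_t = pi_theta(x_t)
            total += r(x, u) + beta * log_p((x, u), data, sigma)
            x = f(x, u)  # x_{t+1} = f(x_t, u_t)
    return total / len(x0_samples)


# Hypothetical usage: zero reward, so only the data-likelihood term differs.
random.seed(0)
data = [(0.1 * k, -0.1) for k in range(10)]  # hypothetical logged (x, u) pairs
x0s = [random.uniform(0.0, 0.9) for _ in range(16)]  # hypothetical samples from rho


def single_integrator(x: float, u: float) -> float:
    return x + u


def no_reward(x: float, u: float) -> float:
    return 0.0


J_near = penalized_objective(lambda x: -0.1, single_integrator, no_reward, x0s, 5, data, 1.0, 0.1)
J_far = penalized_objective(lambda x: 1.0, single_integrator, no_reward, x0s, 5, data, 1.0, 0.1)
print(J_near > J_far)  # True: the policy that reproduces the data is penalized less
```

Dividing by len(x0_samples) is the Monte Carlo estimate of the expectation over ρ. The closed-form log_p is only for illustration; the method named in the title instead learns the score of p with diffusion score matching.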
Our codebase supports general feedback policy optimization, but the examples mainly revolve around open-loop planning. Open-loop planning is a special case of uncertainty-penalized optimal control in which the policy is parameterized as an open-loop sequence of inputs, θ = (u_0, ..., u_{T-1}).
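For open-loop planning, the uncertainty penalty enters the gradient only through the score ∇ log p, which is the quantity diffusion score matching estimates. The sketch below is illustrative, not the repo's code. It assumes a hypothetical scalar single integrator x_{t+1} = x_t + u_t, a hypothetical reward r(x, u) = -(x - goal)^2 - 0.1 u^2, and the closed-form score of a Gaussian-perturbed empirical distribution in place of a learned score model. Because ∂x_t/∂u_k = 1 for t > k, the gradient with respect to u_k is the stage derivative in u_k plus the state derivatives of all later stages; since the repo is written in torch, autograd would normally produce this instead of the hand-written chain rule.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]  # z = (x, u) for a scalar state x and input u


def _logits(z: Point, data: Sequence[Point], sigma: float) -> List[float]:
    # Unnormalized log-weights of the Gaussian mixture centered at the data points.
    return [-((z[0] - a) ** 2 + (z[1] - b) ** 2) / (2 * sigma**2) for a, b in data]


def log_p(z: Point, data: Sequence[Point], sigma: float) -> float:
    logits = _logits(z, data, sigma)
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return lse - math.log(len(data)) - math.log(2 * math.pi * sigma**2)


def score(z: Point, data: Sequence[Point], sigma: float) -> Point:
    # grad_z log p(z) = sum_i w_i (z_i - z) / sigma^2, with softmax weights w_i.
    logits = _logits(z, data, sigma)
    m = max(logits)
    w = [math.exp(v - m) for v in logits]
    scale = sum(w) * sigma**2
    sx = sum(wi * (a - z[0]) for wi, (a, _) in zip(w, data)) / scale
    su = sum(wi * (b - z[1]) for wi, (_, b) in zip(w, data)) / scale
    return sx, su


def rollout(x0: float, us: Sequence[float]) -> List[float]:
    # States x_0 .. x_{T-1} of the hypothetical single integrator x_{t+1} = x_t + u_t.
    xs = [x0]
    for u in us[:-1]:
        xs.append(xs[-1] + u)
    return xs


def objective(x0: float, us: Sequence[float], data: Sequence[Point],
              beta: float, sigma: float, goal: float) -> float:
    xs = rollout(x0, us)
    return sum(
        -((x - goal) ** 2) - 0.1 * u**2 + beta * log_p((x, u), data, sigma)
        for x, u in zip(xs, us)
    )


def gradient(x0: float, us: Sequence[float], data: Sequence[Point],
             beta: float, sigma: float, goal: float) -> List[float]:
    xs = rollout(x0, us)
    dx, du = [], []
    for x, u in zip(xs, us):
        sx, su = score((x, u), data, sigma)
        dx.append(-2.0 * (x - goal) + beta * sx)  # d(stage t)/dx_t
        du.append(-0.2 * u + beta * su)  # d(stage t)/du_t
    grad = [0.0] * len(us)
    later = 0.0  # sum of dx[t] over t > k, since u_k shifts every later state by 1
    for k in reversed(range(len(us))):
        grad[k] = du[k] + later
        later += dx[k]
    return grad


def plan(x0: float, us: Sequence[float], data: Sequence[Point], beta: float,
         sigma: float, goal: float, lr: float = 0.01, steps: int = 200) -> List[float]:
    # Gradient ascent on theta = (u_0, ..., u_{T-1}), the open-loop input sequence.
    us = list(us)
    for _ in range(steps):
        g = gradient(x0, us, data, beta, sigma, goal)
        us = [u + lr * gk for u, gk in zip(us, g)]
    return us


data = [(0.0, 0.5), (0.5, 0.5), (1.0, 0.5)]  # hypothetical logged (x, u) pairs
u_init = [0.2, -0.3, 0.7]
u_opt = plan(0.0, u_init, data, beta=1.0, sigma=0.5, goal=1.0)
print(objective(0.0, u_init, data, 1.0, 0.5, 1.0), objective(0.0, u_opt, data, 1.0, 0.5, 1.0))
```

Raising beta trades reward for keeping the planned (x_t, u_t) pairs in high-density regions of the data.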
This repo is mainly written in torch, and heavily uses wandb and hydra-core. For some robotic examples, drake might be required as a dependency.
To install the dependencies, clone the repo and run
python -m pip install -r requirements.txt
from the cloned directory. Then add the repo to your Python path. Since the code relies on imports such as import examples, we recommend putting the following line at the end of ~/.bashrc (it assumes the repo was cloned to ~/score_po; adjust the path otherwise):
export PYTHONPATH=${HOME}/score_po:${PYTHONPATH}
We use hydra for our examples, and each user needs a config file. For each example, add your own user config file under config/user and modify that example's config files to use your user name.
For example, to run examples/cartpole/learn_model.py:
- Add a profile to examples/cartpole/config/user as new_user.yaml, following the pattern of terry.yaml.
- In examples/cartpole/config/learn_model.yaml, set
    defaults:
      - user: new_user
- Run
    python examples/cartpole/learn_model.py
  from the cloned directory.
We use pytest for testing and CI. To run the tests, run
pytest .
from the cloned directory.
All examples are in the examples folder, each with instructions on how to run it:
- Simple1D
- Cart-pole system
- The pixel-space single integrator: use branch pixels_glen.
- D4RL Mujoco Benchmark
- Box-Keypoint Pushing Example: for hardware code, use branch lcm_hardware.