-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] robustness leaderboard scripts
- Loading branch information
Showing
12 changed files
with
736 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Other imports | ||
import os | ||
import importlib | ||
import argparse | ||
|
||
# Local imports | ||
from simple_pendulum.analysis.benchmark import benchmarker | ||
|
||
from sim_parameters import ( | ||
mass, | ||
length, | ||
damping, | ||
gravity, | ||
coulomb_fric, | ||
torque_limit, | ||
inertia, | ||
dt, | ||
t_final, | ||
integrator, | ||
benchmark_iterations, | ||
) | ||
|
||
|
||
def compute_leaderboard_data(data_dir, con_filename): | ||
controller_arg = con_filename[:-3] | ||
controller_name = controller_arg[4:] | ||
|
||
save_dir = f"{data_dir}/{controller_name}" | ||
if not os.path.exists(save_dir): | ||
os.makedirs(save_dir) | ||
|
||
imp = importlib.import_module(controller_arg) | ||
|
||
controller = imp.controller | ||
|
||
ben = benchmarker( | ||
dt=dt, | ||
max_time=t_final, | ||
integrator=integrator, | ||
benchmark_iterations=benchmark_iterations, | ||
) | ||
|
||
ben.init_pendulum( | ||
mass=mass, | ||
length=length, | ||
inertia=inertia, | ||
damping=damping, | ||
coulomb_friction=coulomb_fric, | ||
gravity=gravity, | ||
torque_limit=torque_limit, | ||
) | ||
|
||
ben.set_controller(controller) | ||
|
||
ben.benchmark( | ||
check_speed=False, | ||
check_energy=False, | ||
check_time=False, | ||
check_smoothness=False, | ||
check_consistency=True, | ||
check_robustness=True, | ||
check_sensitivity=True, | ||
check_torque_limit=True, | ||
save_path=os.path.join(save_dir, "benchmark.yml"), | ||
) | ||
|
||
if os.path.exists(f"readmes/{controller_name}.md"): | ||
os.system(f"cp readmes/{controller_name}.md {save_dir}/README.md") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--controller", | ||
dest="con_filename", | ||
help="Controller file containing the initialization of the controller.", | ||
default="con_energyshaping_lqr.py", | ||
required=True, | ||
) | ||
|
||
con_filename = parser.parse_args().con_filename | ||
|
||
data_dir = "data" | ||
if not os.path.exists(data_dir): | ||
os.makedirs(data_dir) | ||
|
||
print(f"Simulating new controller {con_filename}") | ||
compute_leaderboard_data(data_dir, con_filename) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os | ||
|
||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | ||
import tensorflow as tf | ||
|
||
tf.get_logger().setLevel("ERROR") | ||
|
||
from simple_pendulum.controllers.ddpg.ddpg_controller import ddpg_controller | ||
from sim_parameters import torque_limit | ||
|
||
name = "ddpg" | ||
leaderboard_config = { | ||
"csv_path": name + "/sim_swingup.csv", | ||
"name": name, | ||
"simple_name": "DDPG", | ||
"short_description": "RL Policy learned with Deep Deterministic Policy Gradient.", | ||
"readme_path": f"readmes/{name}.md", | ||
"username": "fwiebe", | ||
} | ||
|
||
torque_limit = 1.0 | ||
model_path = "../../data/models/ddpg_model/actor" | ||
controller = ddpg_controller( | ||
model_path=model_path, torque_limit=torque_limit, state_representation=3 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import numpy as np | ||
|
||
from simple_pendulum.trajectory_optimization.direct_collocation.direct_collocation import ( | ||
DirectCollocationCalculator, | ||
) | ||
from simple_pendulum.controllers.tvlqr.tvlqr import TVLQRController | ||
from simple_pendulum.controllers.lqr.lqr_controller import LQRController | ||
from simple_pendulum.controllers.combined_controller import CombinedController | ||
from simple_pendulum.utilities.process_data import prepare_empty_data_dict | ||
|
||
from sim_parameters import ( | ||
mass, | ||
length, | ||
damping, | ||
gravity, | ||
coulomb_fric, | ||
torque_limit, | ||
inertia, | ||
dt, | ||
t_final, | ||
t0, | ||
x0, | ||
goal, | ||
integrator, | ||
) | ||
|
||
name = "dircol_tvlqr_lqr" | ||
leaderboard_config = { | ||
"csv_path": name + "/sim_swingup.csv", | ||
"name": name, | ||
"simple_name": "Direct collocation and TVLQR", | ||
"short_description": "Direct collocation trajectory stabilized with time-varying LQR.", | ||
"readme_path": f"readmes/{name}.md", | ||
"username": "fwiebe", | ||
} | ||
|
||
# direct collocation parameters | ||
N = 21 | ||
max_dt = 0.5 | ||
|
||
torque_limit = 1.5 | ||
|
||
#################### | ||
# Compute trajectory | ||
#################### | ||
|
||
|
||
dircal = DirectCollocationCalculator() | ||
dircal.init_pendulum( | ||
mass=mass, | ||
length=length, | ||
damping=damping, | ||
gravity=gravity, | ||
torque_limit=torque_limit, | ||
) | ||
x_trajectory, dircol, result = dircal.compute_trajectory( | ||
N=N, max_dt=max_dt, start_state=x0, goal_state=goal | ||
) | ||
T, X, XD, U = dircal.extract_trajectory( | ||
x_trajectory, dircol, result, N=int(x_trajectory.end_time() / dt) | ||
) | ||
|
||
|
||
# save results | ||
data_dict = prepare_empty_data_dict(dt, t_final) | ||
data_dict["des_time"] = T | ||
data_dict["des_pos"] = X | ||
data_dict["des_vel"] = XD | ||
data_dict["des_tau"] = U | ||
|
||
controller1 = TVLQRController( | ||
data_dict=data_dict, | ||
mass=mass, | ||
length=length, | ||
damping=damping, | ||
gravity=gravity, | ||
torque_limit=torque_limit, | ||
) | ||
|
||
controller2 = LQRController( | ||
mass=mass, | ||
length=length, | ||
damping=damping, | ||
coulomb_fric=coulomb_fric, | ||
gravity=gravity, | ||
torque_limit=torque_limit, | ||
Q=np.diag([10, 1]), | ||
R=np.array([[1]]), | ||
compute_RoA=False, | ||
) | ||
|
||
|
||
def condition1(meas_pos, meas_vel, meas_tau, meas_time): | ||
return False | ||
|
||
|
||
def condition2(meas_pos, meas_vel, meas_tau, meas_time): | ||
goal = np.asarray([np.pi, 0.0]) | ||
delta_pos = meas_pos - goal[0] | ||
delta_pos_wrapped = (delta_pos + np.pi) % (2 * np.pi) - np.pi | ||
if np.abs(delta_pos_wrapped) < 0.1 and np.abs(meas_vel) < 0.1: | ||
return True | ||
else: | ||
return False | ||
|
||
|
||
controller = CombinedController( | ||
controller1=controller1, | ||
controller2=controller2, | ||
condition1=condition1, | ||
condition2=condition2, | ||
) | ||
controller.set_goal(goal) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy as np | ||
|
||
# Local imports | ||
from simple_pendulum.controllers.energy_shaping.energy_shaping_controller import ( | ||
EnergyShapingAndLQRController, | ||
) | ||
|
||
from sim_parameters import ( | ||
mass, | ||
length, | ||
damping, | ||
gravity, | ||
coulomb_fric, | ||
torque_limit, | ||
inertia, | ||
dt, | ||
t_final, | ||
t0, | ||
x0, | ||
goal, | ||
integrator, | ||
) | ||
|
||
name = "energyshaping_lqr" | ||
leaderboard_config = { | ||
"csv_path": name + "/sim_swingup.csv", | ||
"name": name, | ||
"simple_name": "Energy Shaping and LQR", | ||
"short_description": "Energy shaping for swingup and LQR for stabilization.", | ||
"readme_path": f"readmes/{name}.md", | ||
"username": "fwiebe", | ||
} | ||
|
||
torque_limit = 1.0 | ||
|
||
controller = EnergyShapingAndLQRController( | ||
mass=mass, | ||
length=length, | ||
damping=damping, | ||
coulomb_fric=coulomb_fric, | ||
gravity=gravity, | ||
torque_limit=torque_limit, | ||
k=1.0, | ||
Q=np.diag([10, 1]), | ||
R=np.array([[1]]), | ||
compute_RoA=False, | ||
) | ||
controller.set_goal(goal) |
Oops, something went wrong.