Added kicking-environment training by adding a reward function and changing the reward weights #1

Open · wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions dribblebot/envs/base/legged_robot_config.py
@@ -319,6 +319,7 @@ class rewards(PrefixProto, cli=False):
front_target = [[0.17, -0.09, 0]]
estimation_bonus_dims = []
estimation_bonus_weights = []


constrict = False
constrict_indices = []
@@ -346,6 +347,7 @@ class reward_scales(ParamsProto, cli=False):
tracking_contacts = 0.
tracking_contacts_shaped = 0.
tracking_contacts_shaped_force = 0.
tracking_contacts_shaped_force_for_kicking = 0.
tracking_contacts_shaped_vel = 0.
jump = 0.0
energy = 0.0
@@ -366,6 +368,7 @@ class reward_scales(ParamsProto, cli=False):
dribbling_robot_ball_vel = 0.0
dribbling_robot_ball_pos = 0.0
dribbling_ball_vel = 0.0
kicking_ball_vel = 0.0
dribbling_robot_ball_yaw = 0.0
dribbling_ball_vel_norm = 0.0
dribbling_ball_vel_angle = 0.0
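For context: in legged_gym-style codebases like this one, each nonzero entry in reward_scales is typically matched by name to a _reward_<name> method on the reward container, so giving tracking_contacts_shaped_force_for_kicking or kicking_ball_vel a nonzero weight is what activates the new terms added below. A minimal sketch of that assumed pattern (names illustrative, not taken verbatim from this repo):

def compute_reward(container, scales):
    # scales: dict of reward name -> weight, e.g. {"kicking_ball_vel": 4.0}
    total = 0.0
    for name, scale in scales.items():
        if scale == 0.0:
            continue  # zero-weighted terms are skipped
        # each nonzero scale maps to a method such as _reward_kicking_ball_vel
        total = total + scale * getattr(container, f"_reward_{name}")()
    return total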
28 changes: 27 additions & 1 deletion dribblebot/rewards/soccer_rewards.py
@@ -46,6 +46,16 @@ def _reward_tracking_contacts_shaped_force(self): # √
reward += - (1 - desired_contact[:, i]) * (
1 - torch.exp(-1 * foot_forces[:, i] ** 2 / self.env.cfg.rewards.gait_force_sigma))
return reward / 4

def _reward_tracking_contacts_shaped_force_for_kicking(self):  # designed for kicking: more powerful kicks are encouraged
foot_forces = torch.norm(self.env.contact_forces[:, self.env.feet_indices, :], dim=-1)
desired_contact = self.env.desired_contact_states

reward = 0
for i in range(4):
reward += -(1 - desired_contact[:, i]) * (
torch.exp(-1 * foot_forces[:, i] ** 2 / self.env.cfg.rewards.gait_force_sigma))
return reward / 4
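A side note on the new term: compared with _reward_tracking_contacts_shaped_force above, this kicking variant drops the (1 - exp(...)) form, so a foot that is not scheduled for contact is penalized most when its contact force is near zero, and the penalty fades as the foot pushes harder, which is what encourages stronger strikes on the ball. A quick standalone check of the two shapes, assuming gait_force_sigma = 100 as set in train_kicking.py:

import torch

gait_force_sigma = 100.0
foot_force = torch.tensor([0.0, 5.0, 20.0])  # contact-force magnitudes to probe

# original shaping: penalty grows with force on a foot that should be in swing
original = -(1 - torch.exp(-foot_force ** 2 / gait_force_sigma))
# kicking shaping: penalty is largest at zero force and vanishes for hard contact
kicking = -torch.exp(-foot_force ** 2 / gait_force_sigma)

print(original)  # approximately [ 0.00, -0.22, -0.98]
print(kicking)   # approximately [-1.00, -0.78, -0.02]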

def _reward_tracking_contacts_shaped_vel(self): # √
foot_velocities = torch.norm(self.env.foot_velocities, dim=2).view(self.env.num_envs, -1)
@@ -96,6 +106,7 @@ def _reward_dribbling_robot_ball_vel(self): # √
velocity_concatenation = torch.cat((torch.zeros(self.env.num_envs,1, device=self.env.device), ball_robot_velocity_projection.unsqueeze(dim=-1)), dim=-1)
rew_dribbling_robot_ball_vel=torch.exp(-delta_dribbling_robot_ball_vel* torch.pow(torch.max(velocity_concatenation,dim=-1).values, 2) )
return rew_dribbling_robot_ball_vel


# encourage robot near ball
# r_cp
@@ -115,7 +126,22 @@ def _reward_dribbling_ball_vel(self): # √
lin_vel_error = torch.sum(torch.square(self.env.commands[:, :2] - self.env.object_lin_vel[:, :2]), dim=1)
# rew_dribbling_ball_vel = torch.exp(-lin_vel_error / (self.env.cfg.rewards.tracking_sigma*2))
return torch.exp(-lin_vel_error / (self.env.cfg.rewards.tracking_sigma*2))


def _reward_kicking_ball_vel(self):  # √ encourage the ball's velocity to align with the commanded direction
# compute the unit vector of the commanded velocity
command_velocity = self.env.commands[:, :2]
command_velocity_norm = torch.norm(command_velocity, dim=1, keepdim=True)
command_velocity_unit = command_velocity / (command_velocity_norm + 1e-8)  # avoid division by zero

# ball velocity component along the commanded direction
ball_velocity = self.env.object_lin_vel[:, :2]
ball_velocity_in_command_direction = torch.sum(ball_velocity * command_velocity_unit, dim=1)

# reward larger ball speed along the commanded direction, clamping the reward to [0, 1]
reward = torch.clamp(1 - torch.exp(-ball_velocity_in_command_direction), 0, 1)

return reward
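A note on the shaping here: because 1 - exp(-v) is negative for v < 0, the clamp zeroes the reward whenever the ball moves against (or perpendicular to) the commanded direction, and the reward saturates toward 1 as the projected ball speed grows. A minimal standalone check:

import torch

command = torch.tensor([[1.0, 0.0]])  # commanded xy velocity
command_unit = command / (command.norm(dim=1, keepdim=True) + 1e-8)

for vel in ([2.0, 0.0], [0.5, 0.5], [-1.0, 0.0]):
    ball_vel = torch.tensor([vel])
    projected = torch.sum(ball_vel * command_unit, dim=1)
    reward = torch.clamp(1 - torch.exp(-projected), 0, 1)
    print(vel, round(reward.item(), 3))
# [2.0, 0.0]  -> 0.865  (ball moving with the command)
# [0.5, 0.5]  -> 0.393  (partially aligned)
# [-1.0, 0.0] -> 0.0    (moving against the command: clamped to zero)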

def _reward_dribbling_robot_ball_yaw(self): # TODO: something wrong, the norm can become 0 here
robot_ball_vec = self.env.object_pos_world_frame[:,0:2] - self.env.base_pos[:,0:2]
d_robot_ball=robot_ball_vec / torch.norm(robot_ball_vec, dim=-1).unsqueeze(dim=-1)
300 changes: 300 additions & 0 deletions scripts/go2/train_kicking.py
Owner
Could this script run correctly now?

@@ -0,0 +1,300 @@
def train_go2(use_wandb=False, resume_flag=False, exp_name="", device='cuda:0', number_envs=1024):

import isaacgym
assert isaacgym
import torch
import wandb

from dribblebot.envs.base.legged_robot_config import Cfg
from dribblebot.envs.go2.go2_config import config_go2
from dribblebot.envs.go2.velocity_tracking import VelocityTrackingEasyEnv

from dribblebot_learn.ppo_cse import Runner
from dribblebot.envs.wrappers.history_wrapper import HistoryWrapper
from dribblebot_learn.ppo_cse.actor_critic import AC_Args
from dribblebot_learn.ppo_cse.ppo import PPO_Args
from dribblebot_learn.ppo_cse import RunnerArgs

config_go2(Cfg)
Cfg.env.num_envs = number_envs # default: 4096

RunnerArgs.resume = resume_flag # use pretrain or not
# RunnerArgs.resume_path = "improbableailab/dribbling/j34kr9ds"
# RunnerArgs.resume_checkpoint = 'tmp/legged_data/ac_weights_last.pt'
RunnerArgs.resume_checkpoint = '/home/zdj/Codes/dribblebot/runs/improbableailab/dribbling/bvggoq26/dribbling_pretrained/ac_weights.pt' # TODO: change this path

Cfg.robot.name = "go2"
Cfg.sensors.sensor_names = [
"ObjectSensor",
"OrientationSensor",
"RCSensor",
"JointPositionSensor",
"JointVelocitySensor",
"ActionSensor",
"LastActionSensor",
"ClockSensor",
"YawSensor",
"TimingSensor",
]
Cfg.sensors.sensor_args = {
"ObjectSensor": {},
"OrientationSensor": {},
"RCSensor": {},
"JointPositionSensor": {},
"JointVelocitySensor": {},
"ActionSensor": {},
"LastActionSensor": {"delay": 1},
"ClockSensor": {},
"YawSensor": {},
"TimingSensor":{},
}
Cfg.sensors.privileged_sensor_names = {
"BodyVelocitySensor": {},
"ObjectVelocitySensor": {},
}
Cfg.sensors.privileged_sensor_args = {
"BodyVelocitySensor": {},
"ObjectVelocitySensor": {},
}

Cfg.commands.num_lin_vel_bins = 30
Cfg.commands.num_ang_vel_bins = 30
Cfg.curriculum_thresholds.tracking_ang_vel = 0.7
Cfg.curriculum_thresholds.tracking_lin_vel = 0.8
Cfg.curriculum_thresholds.tracking_contacts_shaped_vel = 0.90
Cfg.curriculum_thresholds.tracking_contacts_shaped_force = 0.90
Cfg.curriculum_thresholds.dribbling_ball_vel = 0.8

Cfg.commands.distributional_commands = True

Cfg.domain_rand.lag_timesteps = 6
Cfg.domain_rand.randomize_lag_timesteps = True
Cfg.control.control_type = "actuator_net"

Cfg.domain_rand.randomize_rigids_after_start = False
# Cfg.domain_rand.randomize_friction_indep = False
Cfg.domain_rand.randomize_restitution = False # True
Cfg.domain_rand.restitution_range = [0.0, 0.4]
Cfg.domain_rand.randomize_base_mass = True
Cfg.domain_rand.added_mass_range = [-1.0, 3.0]
Cfg.domain_rand.randomize_gravity = False
Cfg.domain_rand.gravity_range = [-1.0, 1.0]
Cfg.domain_rand.gravity_rand_interval_s = 8.0
Cfg.domain_rand.gravity_impulse_duration = 0.99
Cfg.domain_rand.randomize_com_displacement = False
Cfg.domain_rand.com_displacement_range = [-0.15, 0.15]
# Cfg.domain_rand.randomize_ground_friction = True
# Cfg.domain_rand.ground_friction_range = [0.0, 0.0]
Cfg.domain_rand.randomize_motor_strength = True
Cfg.domain_rand.motor_strength_range = [0.99, 1.01]
Cfg.domain_rand.randomize_motor_offset = True
Cfg.domain_rand.motor_offset_range = [-0.002, 0.002]
Cfg.domain_rand.push_robots = False
Cfg.domain_rand.randomize_Kp_factor = True
Cfg.domain_rand.randomize_Kd_factor = True
Cfg.domain_rand.randomize_ball_drag = True
Cfg.domain_rand.drag_range = [0.1, 0.8]
Cfg.domain_rand.ball_drag_rand_interval_s = 15.0

Cfg.env.num_observation_history = 15
Cfg.reward_scales.feet_contact_forces = 0.0

Cfg.commands.exclusive_phase_offset = False
Cfg.commands.pacing_offset = False
Cfg.commands.balance_gait_distribution = False
Cfg.commands.binary_phases = False
Cfg.commands.gaitwise_curricula = False

###############################
# soccer dribbling configuration
###############################

# ball parameters
Cfg.env.add_balls = True

# domain randomization ranges
Cfg.domain_rand.rand_interval_s = 6
Cfg.domain_rand.randomize_friction = False # True # TODO: randomize friction
Cfg.domain_rand.friction_range = [0.0, 1.5]
Cfg.domain_rand.randomize_ground_friction = True # TODO: randomize ground friction
Cfg.domain_rand.ground_friction_range = [0.7, 4.0] # default: [0.7, 4.0] change2: [0.4, 1.5]
Cfg.domain_rand.restitution_range = [0.0, 0.4]
Cfg.domain_rand.added_mass_range = [-1.0, 3.0]
Cfg.domain_rand.gravity_range = [-1.0, 1.0]
Cfg.domain_rand.motor_strength_range = [0.99, 1.01]
Cfg.domain_rand.motor_offset_range = [-0.002, 0.002]
Cfg.domain_rand.tile_roughness_range = [0.0, 0.0]

# privileged obs in use
Cfg.env.num_privileged_obs = 6
Cfg.env.priv_observe_ball_drag = True

# sensory observation
Cfg.commands.num_commands = 15
Cfg.env.episode_length_s = 40.
Cfg.env.num_observations = 75

# terrain configuration
Cfg.terrain.border_size = 0.0
Cfg.terrain.mesh_type = "boxes_tm"
Cfg.terrain.num_cols = 20
Cfg.terrain.num_rows = 20
Cfg.terrain.terrain_length = 5.0
Cfg.terrain.terrain_width = 5.0
Cfg.terrain.num_border_boxes = 5.0
Cfg.terrain.x_init_range = 0.2
Cfg.terrain.y_init_range = 0.2
Cfg.terrain.teleport_thresh = 0.3
Cfg.terrain.teleport_robots = False
Cfg.terrain.center_robots = False
Cfg.terrain.center_span = 3
Cfg.terrain.horizontal_scale = 0.05
Cfg.terrain.terrain_proportions = [1.0, 0.0, 0.0, 0.0, 0.0]
Cfg.terrain.curriculum = False
Cfg.terrain.difficulty_scale = 1.0
Cfg.terrain.max_step_height = 0.26
Cfg.terrain.min_step_run = 0.25
Cfg.terrain.max_step_run = 0.4
Cfg.terrain.max_init_terrain_level = 1

# terminal conditions
Cfg.rewards.use_terminal_body_height = True
Cfg.rewards.terminal_body_height = 0.2
Cfg.rewards.use_terminal_roll_pitch = False
Cfg.rewards.terminal_body_ori = 0.5

# command sampling
Cfg.commands.resampling_time = 7
Cfg.commands.heading_command = False

Cfg.commands.lin_vel_x = [-1.5, 1.5]
Cfg.commands.lin_vel_y = [-1.5, 1.5]
Cfg.commands.ang_vel_yaw = [-0.0, 0.0]
Cfg.commands.body_height_cmd = [-0.05, 0.05]
Cfg.commands.gait_frequency_cmd_range = [3.0, 3.0]
Cfg.commands.gait_phase_cmd_range = [0.5, 0.5]
Cfg.commands.gait_offset_cmd_range = [0.0, 0.0]
Cfg.commands.gait_bound_cmd_range = [0.0, 0.0]
Cfg.commands.gait_duration_cmd_range = [0.5, 0.5]
Cfg.commands.footswing_height_range = [0.09, 0.09]
Cfg.commands.body_pitch_range = [-0.0, 0.0]
Cfg.commands.body_roll_range = [-0.0, 0.0]
Cfg.commands.stance_width_range = [0.0, 0.1]
Cfg.commands.stance_length_range = [0.0, 0.1]

Cfg.commands.limit_vel_x = [-1.5, 1.5]
Cfg.commands.limit_vel_y = [-1.5, 1.5]
Cfg.commands.limit_vel_yaw = [-0.0, 0.0]
Cfg.commands.limit_body_height = [-0.05, 0.05]
Cfg.commands.limit_gait_frequency = [3.0, 3.0]
Cfg.commands.limit_gait_phase = [0.5, 0.5]
Cfg.commands.limit_gait_offset = [0.0, 0.0]
Cfg.commands.limit_gait_bound = [0.0, 0.0]
Cfg.commands.limit_gait_duration = [0.5, 0.5]
Cfg.commands.limit_footswing_height = [0.09, 0.09]
Cfg.commands.limit_body_pitch = [-0.0, 0.0]
Cfg.commands.limit_body_roll = [-0.0, 0.0]
Cfg.commands.limit_stance_width = [0.0, 0.1]
Cfg.commands.limit_stance_length = [0.0, 0.1]

Cfg.commands.num_bins_vel_x = 1
Cfg.commands.num_bins_vel_y = 1
Cfg.commands.num_bins_vel_yaw = 1
Cfg.commands.num_bins_body_height = 1
Cfg.commands.num_bins_gait_frequency = 1
Cfg.commands.num_bins_gait_phase = 1
Cfg.commands.num_bins_gait_offset = 1
Cfg.commands.num_bins_gait_bound = 1
Cfg.commands.num_bins_gait_duration = 1
Cfg.commands.num_bins_footswing_height = 1
Cfg.commands.num_bins_body_roll = 1
Cfg.commands.num_bins_body_pitch = 1
Cfg.commands.num_bins_stance_width = 1

Cfg.rewards.constrict = False

# reward function
Cfg.reward_scales.orientation = -5.0 # TODO default: -5.0 change2: -20.0
Cfg.reward_scales.torques = -0.0001
Cfg.reward_scales.dof_vel = -0.0001
Cfg.reward_scales.dof_acc = -2.5e-7
Cfg.reward_scales.collision = -5.0
Cfg.reward_scales.action_rate = -0.01
Cfg.reward_scales.tracking_contacts_shaped_force = 0.0
Cfg.reward_scales.tracking_contacts_shaped_force_for_kicking = 4.0
Cfg.reward_scales.tracking_contacts_shaped_vel = 4.0
Cfg.reward_scales.dof_pos_limits = -10.0
Cfg.reward_scales.dof_pos = -0.05
Cfg.reward_scales.action_smoothness_1 = -0.1
Cfg.reward_scales.action_smoothness_2 = -0.1
Cfg.reward_scales.dribbling_robot_ball_vel = 0.5
Cfg.reward_scales.dribbling_robot_ball_pos = 0.0
Cfg.reward_scales.dribbling_ball_vel = 0.0
Cfg.reward_scales.kicking_ball_vel = 4.0
Cfg.reward_scales.dribbling_robot_ball_yaw = 20.0 # TODO default: 4.0 change2: 20.0
Cfg.reward_scales.dribbling_ball_vel_norm = 0.0
Cfg.reward_scales.dribbling_ball_vel_angle = 4.0
Cfg.reward_scales.tracking_lin_vel = 0.0
Cfg.reward_scales.tracking_ang_vel = 0.0
Cfg.reward_scales.lin_vel_z = 0.0
Cfg.reward_scales.ang_vel_xy = 0.0
Cfg.reward_scales.feet_air_time = 0.0

Cfg.rewards.kappa_gait_probs = 0.07
Cfg.rewards.gait_force_sigma = 100.
Cfg.rewards.gait_vel_sigma = 10.

Cfg.rewards.reward_container_name = "SoccerRewards"
Cfg.rewards.only_positive_rewards = False
Cfg.rewards.only_positive_rewards_ji22_style = True
Cfg.rewards.sigma_rew_neg = 0.02

# normalization
Cfg.normalization.friction_range = [0, 1]
Cfg.normalization.ground_friction_range = [0.7, 4.0] # TODO default: [0.7, 4.0] change2: [0.4, 1.5]
Cfg.terrain.yaw_init_range = 3.14
Cfg.normalization.clip_actions = 10.0

# reward function (not in use)
Cfg.reward_scales.feet_slip = -0.0
Cfg.reward_scales.jump = 0.0
Cfg.reward_scales.base_height = -0.0
Cfg.reward_scales.feet_impact_vel = -0.0
Cfg.reward_scales.feet_air_time = 0.0

Cfg.asset.terminate_after_contacts_on = []

AC_Args.adaptation_labels = []
AC_Args.adaptation_dims = []

RunnerArgs.save_video_interval = 500

wandb.init(
# set the wandb project where this run will be logged
mode="disabled" if use_wandb is False else "online",
project="dribbling",
entity="xander2077",
name=exp_name,
# track hyperparameters and run metadata
config={
"AC_Args": vars(AC_Args),
"PPO_Args": vars(PPO_Args),
"RunnerArgs": vars(RunnerArgs),
"Cfg": vars(Cfg),
}
)

env = VelocityTrackingEasyEnv(sim_device=device, headless=False, cfg=Cfg)
env = HistoryWrapper(env)
runner = Runner(env, device=device)
runner.learn(num_learning_iterations=1000000, init_at_random_ep_len=True, eval_freq=100)


if __name__ == '__main__':
from pathlib import Path
from dribblebot import MINI_GYM_ROOT_DIR

stem = Path(__file__).stem

# to see the environment rendering, set headless=False
train_go2(use_wandb=True, resume_flag=False, exp_name="Test", device='cuda:0', number_envs=256)
2 changes: 1 addition & 1 deletion scripts/train_dribbling.py
@@ -284,7 +284,7 @@ def train_go1(headless=True):
}
)

device = 'cuda:1'
device = 'cuda:0'
# device = 'cpu'
env = VelocityTrackingEasyEnv(sim_device=device, headless=True, cfg=Cfg)

2 changes: 1 addition & 1 deletion setup.py
Owner
keep it

@@ -15,7 +15,7 @@
'tqdm',
'matplotlib',
'numpy==1.23.5',
'wandb', # 'wandb==0.15.0',
'wandb==0.15.0', # 'wandb==0.15.0',
'wandb_osh',
#'moviepy',
'imageio'