Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nav2_mppi_controller/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ add_library(mppi_critics SHARED
src/critics/prefer_forward_critic.cpp
src/critics/twirling_critic.cpp
src/critics/velocity_deadband_critic.cpp
src/critics/direction_change_critic.cpp
)
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND APPLE)
# Apple Clang: use C++20 and optimization, omit -fconcepts
Expand Down
4 changes: 4 additions & 0 deletions nav2_mppi_controller/critics.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,9 @@
<description>mppi critic for restricting command velocities in deadband range</description>
</class>

<class type="mppi::critics::DirectionChangeCritic" base_class_type="mppi::critics::CriticFunction">
<description>mppi critic for penalizing changes in driving direction</description>
</class>

</library>
</class_libraries>
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2024 Enway GmbH, Adi Vardi
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef NAV2_MPPI_CONTROLLER__CRITICS__DIRECTION_CHANGE_CRITIC_HPP_
#define NAV2_MPPI_CONTROLLER__CRITICS__DIRECTION_CHANGE_CRITIC_HPP_

#include "nav2_mppi_controller/critic_function.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"

namespace mppi::critics
{

/**
* @class mppi::critics::DirectionChangeCritic
* @brief Critic objective function for penalizing changes in driving direction.
*/
class DirectionChangeCritic : public CriticFunction
{
public:
/**
* @brief Initialize critic
*/
void initialize() override;

/**
* @brief Evaluate cost related to changing driving direction
*
* @param costs [out] add cost values to this tensor
*/
void score(CriticData & data) override;

protected:
unsigned int power_{0};
float weight_{0};
float threshold_to_consider_{0};
};

} // namespace mppi::critics

#endif // NAV2_MPPI_CONTROLLER__CRITICS__DIRECTION_CHANGE_CRITIC_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,17 @@ class PathAlignCritic : public CriticFunction
size_t offset_from_furthest_{0};
int trajectory_point_step_{0};
float threshold_to_consider_{0};
float occupancy_check_min_distance_{0};
float max_path_occupancy_ratio_{0};
bool use_path_orientations_{false};
unsigned int power_{0};
float weight_{0};

bool visualize_furthest_point_{false};
nav2::Publisher<geometry_msgs::msg::PoseStamped>::SharedPtr furthest_point_pub_;

bool visualize_occupancy_check_distance_{false};
nav2::Publisher<geometry_msgs::msg::PoseStamped>::SharedPtr occupancy_check_dist_pub_;
};

} // namespace mppi::critics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ struct State
Eigen::ArrayXXf cwz;

geometry_msgs::msg::PoseStamped pose;
// current speed or last command published, depends on open_loop setting
geometry_msgs::msg::Twist speed;
geometry_msgs::msg::Twist robot_speed; // current speed from odometry
float local_path_length;

/**
Expand Down
69 changes: 69 additions & 0 deletions nav2_mppi_controller/src/critics/direction_change_critic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) 2024 Enway GmbH, Adi Vardi
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "nav2_mppi_controller/critics/direction_change_critic.hpp"

#include <Eigen/Dense>

namespace mppi::critics
{

void DirectionChangeCritic::initialize()
{
auto getParentParam = parameters_handler_->getParamGetter(parent_name_);
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 5.0f);
getParam(
threshold_to_consider_,
"threshold_to_consider", 0.5f);

RCLCPP_INFO(
logger_, "DirectionChangeCritic instantiated with %d power and %f weight.", power_, weight_);
}

void DirectionChangeCritic::score(CriticData & data)
{
if (!enabled_) {
return;
}

if (data.state.local_path_length < threshold_to_consider_) {
return;
}

// Penalize the magnitude of velocity difference when crossing zero (direction change)
// Calculate |vx - current_speed| only where signs differ, otherwise 0
// Use robot_speed from feedback, as it better represents the actual direction of motion
constexpr size_t penalize_up_to_idx = 2;
const float current_speed = data.state.robot_speed.linear.x;
// Process in-place using Eigen views to avoid allocations
auto vx_view = data.state.vx.leftCols(penalize_up_to_idx);

if (power_ > 1u) {
data.costs += ((vx_view * current_speed < 0.0f).select(
(vx_view - current_speed).abs(), 0.0f).rowwise().sum() * weight_).pow(power_);
} else {
data.costs += (vx_view * current_speed < 0.0f).select(
(vx_view - current_speed).abs(), 0.0f).rowwise().sum() * weight_;
}
}

} // namespace mppi::critics

#include <pluginlib/class_list_macros.hpp>

PLUGINLIB_EXPORT_CLASS(
mppi::critics::DirectionChangeCritic,
mppi::critics::CriticFunction)
97 changes: 74 additions & 23 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ void PathAlignCritic::initialize()
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(power_, "cost_power", 1);
getParam(weight_, "cost_weight", 10.0f);
getParam(occupancy_check_min_distance_, "occupancy_check_min_distance", 2.0f);
getParam(max_path_occupancy_ratio_, "max_path_occupancy_ratio", 0.07f);
getParam(offset_from_furthest_, "offset_from_furthest", 20);
getParam(trajectory_point_step_, "trajectory_point_step", 4);
Expand All @@ -41,6 +42,17 @@ void PathAlignCritic::initialize()
}
}

getParam(visualize_occupancy_check_distance_, "visualize_occupancy_check_distance", false);

if (visualize_occupancy_check_distance_) {
auto node = parent_.lock();
if (node) {
occupancy_check_dist_pub_ = node->create_publisher<geometry_msgs::msg::PoseStamped>(
"/critics/PathAlignCritic/occupancy_check_end_point", 1);
occupancy_check_dist_pub_->on_activate();
}
}

RCLCPP_INFO(
logger_,
"ReferenceTrajectoryCritic instantiated with %d power and %f weight",
Expand All @@ -53,48 +65,43 @@ void PathAlignCritic::score(CriticData & data)
return;
}

// Don't apply when first getting bearing w.r.t. the path
// Only apply critic when trajectories reach far enough along the way path.
// This ensures that path alignment is only considered when actually tracking the path
// (e.g. not driving very slow or when first getting bearing w.r.t. the path)
utils::setPathFurthestPointIfNotSet(data);
// Up to furthest only, closest path point is always 0 from path handler
const size_t path_segments_count = *data.furthest_reached_path_point;
float path_segments_flt = static_cast<float>(path_segments_count);

// Visualize target pose if enabled
if (visualize_furthest_point_ && path_segments_count > 0 &&
const auto now = clock_->now();
// Visualize furthest reached pose if enabled
if (visualize_furthest_point_ && *data.furthest_reached_path_point > 0 &&
furthest_point_pub_->get_subscription_count() > 0)
{
auto furthest_point = std::make_unique<geometry_msgs::msg::PoseStamped>();
furthest_point->header.frame_id = costmap_ros_->getGlobalFrameID();
furthest_point->header.stamp = clock_->now();
furthest_point->pose.position.x = data.path.x(path_segments_count);
furthest_point->pose.position.y = data.path.y(path_segments_count);
furthest_point->header.stamp = now;
furthest_point->pose.position.x = data.path.x(*data.furthest_reached_path_point);
furthest_point->pose.position.y = data.path.y(*data.furthest_reached_path_point);
furthest_point->pose.position.z = 0.0;
tf2::Quaternion quat;
quat.setRPY(0.0, 0.0, data.path.yaws(path_segments_count));
quat.setRPY(0.0, 0.0, data.path.yaws(*data.furthest_reached_path_point));
furthest_point->pose.orientation = tf2::toMsg(quat);
furthest_point_pub_->publish(std::move(furthest_point));
}

if (path_segments_count < offset_from_furthest_) {
if (*data.furthest_reached_path_point < offset_from_furthest_) {
return;
}

// Don't apply when dynamic obstacles are blocking significant proportions of the local path
utils::setPathCostsIfNotSet(data, costmap_ros_);
std::vector<bool> & path_pts_valid = *data.path_pts_valid;
float invalid_ctr = 0.0f;
for (size_t i = 0; i < path_segments_count; i++) {
if (!path_pts_valid[i]) {invalid_ctr += 1.0f;}
if (invalid_ctr / path_segments_flt > max_path_occupancy_ratio_ && invalid_ctr > 2.0f) {
return;
}
}

const size_t batch_size = data.trajectories.x.rows();
Eigen::ArrayXf cost(data.costs.rows());
cost.setZero();

// Find integrated distance in the path
// Find integrated arc-length distance along the path = total dist traveled along the path to each
// path point
// loop until end of path, to guarantee don't truncate long trajectories when
// furthest_reached_path_point is small (e.g. when all trajectories curve away from the path)
const size_t path_segments_count = data.path.x.size() - 1;
// initialize the occupancy check id to max, in case the entire path is within the distance
size_t occupancy_check_distance_idx = path_segments_count;
std::vector<float> path_integrated_distances(path_segments_count, 0.0f);
std::vector<utils::Pose2D> path(path_segments_count);
float dx = 0.0f, dy = 0.0f;
Expand All @@ -107,6 +114,46 @@ void PathAlignCritic::score(CriticData & data)
dx = data.path.x(i) - pose.x;
dy = data.path.y(i) - pose.y;
path_integrated_distances[i] = path_integrated_distances[i - 1] + sqrtf(dx * dx + dy * dy);

// find the first path point that is further than
// max(occupancy_check_min_distance_, furthest_reached_path_point)
if (occupancy_check_distance_idx == path_segments_count &&
path_integrated_distances[i] > occupancy_check_min_distance_ &&
i >= *data.furthest_reached_path_point)
{
occupancy_check_distance_idx = i;
}
}

// Visualize occupancy check distance if enabled
if (visualize_occupancy_check_distance_ &&
occupancy_check_dist_pub_->get_subscription_count() > 0)
{
auto occupancy_check_point = std::make_unique<geometry_msgs::msg::PoseStamped>();
occupancy_check_point->header.frame_id = costmap_ros_->getGlobalFrameID();
occupancy_check_point->header.stamp = now;
occupancy_check_point->pose.position.x = data.path.x(occupancy_check_distance_idx);
occupancy_check_point->pose.position.y = data.path.y(occupancy_check_distance_idx);
occupancy_check_point->pose.position.z = 0.0;
tf2::Quaternion quat;
quat.setRPY(0.0, 0.0, data.path.yaws(occupancy_check_distance_idx));
occupancy_check_point->pose.orientation = tf2::toMsg(quat);
occupancy_check_dist_pub_->publish(std::move(occupancy_check_point));
}

// Don't apply when dynamic obstacles are blocking significant proportions of the path
// up to occupancy_check_min_distance_
const float occupancy_check_distance_idx_flt = static_cast<float>(occupancy_check_distance_idx);
utils::setPathCostsIfNotSet(data, costmap_ros_);
std::vector<bool> & path_pts_valid = *data.path_pts_valid;
float invalid_ctr = 0.0f;
for (size_t i = 0; i < occupancy_check_distance_idx; i++) {
if (!path_pts_valid[i]) {invalid_ctr += 1.0f;}
if (invalid_ctr / occupancy_check_distance_idx_flt > max_path_occupancy_ratio_ &&
invalid_ctr > 2.0f)
{
return;
}
}

// Finish populating the path vector
Expand Down Expand Up @@ -145,6 +192,10 @@ void PathAlignCritic::score(CriticData & data)
path_pt = 0u;
float Tx_m1 = T_x(t, 0);
float Ty_m1 = T_y(t, 0);
// At each (strided) traj point, find the path point whose integrated arc-length distance along
// the path is closest to the trajectory point's integrated distance along the trajectory.
// if that path point is not in collision, compute the Euclidean distance between the matching
// path pt & traj pt the total cost is the average of those distances across the trajectory
for (int p = 1; p < traj_sampled_size; p++) {
const float Tx = T_x(t, p);
const float Ty = T_y(t, p);
Expand Down
2 changes: 2 additions & 0 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ void Optimizer::prepare(
{
state_.pose = robot_pose;
state_.speed = settings_.open_loop ? last_command_vel_ : robot_speed;
state_.robot_speed = robot_speed;

state_.local_path_length = nav2_util::geometry_utils::calculate_path_length(plan);
path_ = utils::toTensor(plan);
costs_.setZero(settings_.batch_size);
Expand Down
Loading