Skip to content

Commit

Permalink
planning: add learning_model_sample scenario execution
Browse files Browse the repository at this point in the history
  • Loading branch information
jmtao authored and xiaoxq committed May 13, 2020
1 parent c56f6c7 commit 9b1c8c9
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 28 deletions.
7 changes: 4 additions & 3 deletions modules/planning/common/planning_gflags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,10 @@ DEFINE_string(
"a list of source files or directories for offline mode. "
"The items need to be separated by colon ':'. ");
DEFINE_int32(planning_offline_mode, 0,
"0: no learning"
"1: online mode, no dump file"
"2: dump learning_data to <record file>.<n>.bin");
"0: no learning "
"1: online learning, no dump file "
"2: offline learning. read record files and dump learning_data "
" to <record file>.<n>.bin");
DEFINE_int32(learning_data_obstacle_history_time_sec, 3.0,
"time sec (second) of history trajectory points for a obstacle");
DEFINE_int32(learning_data_frame_num_per_file, 100,
Expand Down
3 changes: 3 additions & 0 deletions modules/planning/conf/planning.conf
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
--enable_multi_thread_in_dp_st_graph
--use_osqp_optimizer_for_reference_line

# 0-no learning; 1-online learning; 3-offline: read records and dump learning data
--planning_offline_mode=0

# --smoother_config_filename=/apollo/modules/planning/conf/spiral_smoother_config.pb.txt
# --smoother_config_filename=/apollo/modules/planning/conf/qp_spline_smoother_config.pb.txt
--smoother_config_filename=/apollo/modules/planning/conf/discrete_points_smoother_config.pb.txt
Expand Down
55 changes: 37 additions & 18 deletions modules/planning/scenarios/scenario_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ void ScenarioManager::RegisterScenarios() {
ACHECK(Scenario::LoadConfig(FLAGS_scenario_emergency_stop_config_file,
&config_map_[ScenarioConfig::EMERGENCY_STOP]));

// learning model
ACHECK(Scenario::LoadConfig(
FLAGS_scenario_learning_model_sample_config_file,
&config_map_[ScenarioConfig::LEARNING_MODEL_SAMPLE]));

// park_and_go
ACHECK(Scenario::LoadConfig(FLAGS_scenario_park_and_go_config_file,
&config_map_[ScenarioConfig::PARK_AND_GO]));
Expand All @@ -158,11 +163,6 @@ void ScenarioManager::RegisterScenarios() {
FLAGS_scenario_stop_sign_unprotected_config_file,
&config_map_[ScenarioConfig::STOP_SIGN_UNPROTECTED]));

// learning model
ACHECK(Scenario::LoadConfig(
FLAGS_scenario_learning_model_sample_config_file,
&config_map_[ScenarioConfig::LEARNING_MODEL_SAMPLE]));

// traffic_light
ACHECK(Scenario::LoadConfig(
FLAGS_scenario_traffic_light_protected_config_file,
Expand Down Expand Up @@ -790,19 +790,46 @@ void ScenarioManager::Update(const common::TrajectoryPoint& ego_point,

Observe(frame);

ScenarioDispatch(ego_point, frame);
ScenarioDispatch(frame);
}

void ScenarioManager::ScenarioDispatch(const common::TrajectoryPoint& ego_point,
const Frame& frame) {
void ScenarioManager::ScenarioDispatch(const Frame& frame) {
ACHECK(!frame.reference_line_info().empty());

ScenarioConfig::ScenarioType scenario_type;
if (FLAGS_planning_offline_mode == 1) {
scenario_type = ScenarioDispatchLearning();
} else {
scenario_type = ScenarioDispatchNonLearning(frame);
}

ADEBUG << "select scenario: "
<< ScenarioConfig::ScenarioType_Name(scenario_type);

// update PlanningContext
UpdatePlanningContext(frame, scenario_type);

if (current_scenario_->scenario_type() != scenario_type) {
current_scenario_ = CreateScenario(scenario_type);
}
}

ScenarioConfig::ScenarioType ScenarioManager::ScenarioDispatchLearning() {
////////////////////////////////////////
// learning model scenario
ScenarioConfig::ScenarioType scenario_type =
ScenarioConfig::LEARNING_MODEL_SAMPLE;
return scenario_type;
}

ScenarioConfig::ScenarioType ScenarioManager::ScenarioDispatchNonLearning(
const Frame& frame) {
////////////////////////////////////////
// default: LANE_FOLLOW
ScenarioConfig::ScenarioType scenario_type = default_scenario_type_;

////////////////////////////////////////
// Pad Msg Scenario
// Pad Msg scenario
scenario_type = SelectPadMsgScenario(frame);

if (scenario_type == default_scenario_type_) {
Expand Down Expand Up @@ -860,15 +887,7 @@ void ScenarioManager::ScenarioDispatch(const common::TrajectoryPoint& ego_point,
scenario_type = SelectValetParkingScenario(frame);
}

ADEBUG << "select scenario: "
<< ScenarioConfig::ScenarioType_Name(scenario_type);

// update PlanningContext
UpdatePlanningContext(frame, scenario_type);

if (current_scenario_->scenario_type() != scenario_type) {
current_scenario_ = CreateScenario(scenario_type);
}
return scenario_type;
}

bool ScenarioManager::IsBareIntersectionScenario(
Expand Down
5 changes: 3 additions & 2 deletions modules/planning/scenarios/scenario_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ class ScenarioManager final {

ScenarioConfig::ScenarioType SelectParkAndGoScenario(const Frame& frame);

void ScenarioDispatch(const common::TrajectoryPoint& ego_point,
const Frame& frame);
void ScenarioDispatch(const Frame& frame);
ScenarioConfig::ScenarioType ScenarioDispatchLearning();
ScenarioConfig::ScenarioType ScenarioDispatchNonLearning(const Frame& frame);

bool IsBareIntersectionScenario(
const ScenarioConfig::ScenarioType& scenario_type);
Expand Down
1 change: 1 addition & 0 deletions modules/planning/tasks/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cc_library(
"//modules/planning/tasks/deciders/speed_bounds_decider",
"//modules/planning/tasks/deciders/speed_decider",
"//modules/planning/tasks/deciders/st_bounds_decider",
"//modules/planning/tasks/learning_model:learning_based_task",
"//modules/planning/tasks/optimizers/open_space_trajectory_generation:open_space_trajectory_provider",
"//modules/planning/tasks/optimizers/open_space_trajectory_partition",
"//modules/planning/tasks/optimizers/path_time_heuristic:path_time_heuristic_optimizer",
Expand Down
10 changes: 8 additions & 2 deletions modules/planning/tasks/learning_model/learning_based_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,20 @@ using apollo::common::ErrorCode;
using apollo::common::Status;
using apollo::common::TrajectoryPoint;

LearningBasedTask::LearningBasedTask(const TaskConfig &config)
: Task(config), device_(torch::kCPU) {
ACHECK(config.has_learning_based_task_config());
}

Status LearningBasedTask::Execute(Frame *frame,
ReferenceLineInfo *reference_line_info) {
Task::Execute(frame, reference_line_info);
return Process(frame);
}

Status LearningBasedTask::Process(Frame *frame) {
const auto model_file = config_.model_file();
auto& config = config_.learning_based_task_config();
const auto model_file = config.model_file();
if (apollo::cyber::common::PathExists(model_file)) {
try {
model_ = torch::jit::load(model_file, device_);
Expand All @@ -47,7 +53,7 @@ Status LearningBasedTask::Process(Frame *frame) {
"learning based task model file not exist");
}
}
input_feature_num_ = config_.input_feature_num();
input_feature_num_ = config.input_feature_num();

std::vector<torch::jit::IValue> input_features;
ExtractFeatures(frame, &input_features);
Expand Down
4 changes: 1 addition & 3 deletions modules/planning/tasks/learning_model/learning_based_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ namespace planning {

class LearningBasedTask : public Task {
public:
explicit LearningBasedTask(const TaskConfig &config)
: Task(config), device_(torch::kCPU) {}
explicit LearningBasedTask(const TaskConfig &config);

apollo::common::Status Execute(
Frame *frame, ReferenceLineInfo *reference_line_info) override;
Expand All @@ -47,7 +46,6 @@ class LearningBasedTask : public Task {
bool InferenceModel(const std::vector<torch::jit::IValue> &input_features,
Frame* frame);
private:
LearningBasedTaskConfig config_;
torch::Device device_;
torch::jit::script::Module model_;
int input_feature_num_ = 0;
Expand Down
8 changes: 8 additions & 0 deletions modules/planning/tasks/task_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "modules/planning/tasks/deciders/speed_bounds_decider/speed_bounds_decider.h"
#include "modules/planning/tasks/deciders/speed_decider/speed_decider.h"
#include "modules/planning/tasks/deciders/st_bounds_decider/st_bounds_decider.h"
#include "modules/planning/tasks/learning_model/learning_based_task.h"
#include "modules/planning/tasks/optimizers/open_space_trajectory_generation/open_space_trajectory_provider.h"
#include "modules/planning/tasks/optimizers/open_space_trajectory_partition/open_space_trajectory_partition.h"
#include "modules/planning/tasks/optimizers/path_time_heuristic/path_time_heuristic_optimizer.h"
Expand Down Expand Up @@ -152,6 +153,13 @@ void TaskFactory::Init(const PlanningConfig& config) {
[](const TaskConfig& config) -> Task* {
return new PiecewiseJerkSpeedOptimizer(config);
});
///////////////////////////
// other tasks
task_factory_.Register(TaskConfig::LEARNING_BASED_TASK,
[](const TaskConfig& config) -> Task* {
return new LearningBasedTask(config);
});

for (const auto& default_task_config : config.default_task_config()) {
default_task_configs_[default_task_config.task_type()] =
default_task_config;
Expand Down

0 comments on commit 9b1c8c9

Please sign in to comment.