diff --git a/requirements-build.txt b/requirements-build.txt
index 08772fc73..afbfbcc72 100644
--- a/requirements-build.txt
+++ b/requirements-build.txt
@@ -1,4 +1,4 @@
 cython>=0.29
 build>=0.7.0
 auditwheel>=4
-numpy>=1.22.4
\ No newline at end of file
+numpy>=1.21.6
\ No newline at end of file
diff --git a/requirements-doc.txt b/requirements-doc.txt
index eec7e0f33..29009919b 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -8,7 +8,7 @@ plantumlcli>=0.0.4
 packaging
 sphinx-multiversion~=0.2.4
 where~=1.0.2
-numpy>=1.22.4,<2
+numpy>=1.19,<2
 easydict>=1.7,<2
 scikit-learn>=0.24.2
 nbsphinx>=0.8.8
diff --git a/requirements-test.txt b/requirements-test.txt
index 559aa3f9a..c871bdea4 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -11,6 +11,6 @@ pytest-benchmark~=3.4.0
 testtools>=2
 hbutils>=0.6.13
 setuptools<=59.5.0
-numpy>=1.22.4
+numpy>=1.21.6
 easydict>=1.7,<2
 swig >= 4.1.1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d9f43a19d..877f3f06d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 DI-engine>=0.4.7
 gymnasium[atari]
+moviepy
 numpy>=1.22.4
 pympler
 minigrid
diff --git a/zoo/metadrive/config/metadrive_sampled_efficientzero_config.py b/zoo/metadrive/config/metadrive_sampled_efficientzero_config.py
new file mode 100644
index 000000000..a7742fb4f
--- /dev/null
+++ b/zoo/metadrive/config/metadrive_sampled_efficientzero_config.py
@@ -0,0 +1,103 @@
+from easydict import EasyDict
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+continuous_action_space = True
+K = 20
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 200
+batch_size = 64
+max_env_step = int(1e6)
+reanalyze_ratio = 0.
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+metadrive_sampled_efficientzero_config = dict(
+    exp_name=
+    f'data_sez_ctree/sez_metadrive_old{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+    env=dict(
+        env_name='MetaDrive',
+        continuous=True,
+        obs_shape=[5, 84, 84],
+        manually_discretization=False,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        manager=dict(shared_memory=False, ),
+        metadrive=dict(
+            use_render=False,
+            traffic_density=0.20,  # Density of vehicles occupying the roads, range in [0, 1]
+            map='XSOS',  # Int or string: an easy way to fill map_config
+            horizon=4000,  # Max step number
+            driving_reward=1.0,  # Reward to encourage agent to move forward.
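+            # Note: driving_reward scales the per-step forward progress (in meters) along the current lane;
+            # see reward_function() in zoo/metadrive/env/drive_env.py below.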
+            speed_reward=0.1,  # Reward to encourage agent to drive at a high speed
+            use_lateral_reward=False,  # Reward for lane keeping
+            out_of_road_penalty=40.0,  # Penalty to discourage driving out of road
+            crash_vehicle_penalty=40.0,  # Penalty to discourage collision
+            decision_repeat=10,  # Number of physics steps each decision is repeated for (reciprocal of decision frequency)
+            out_of_route_done=True,  # Game over if driving out of road
+        ),
+
+    ),
+    policy=dict(
+        model=dict(
+            observation_shape=[5, 84, 84],
+            action_space_size=2,
+            continuous_action_space=continuous_action_space,
+            num_of_sampled_actions=K,
+            sigma_type='conditioned',
+            model_type='conv',  # options={'mlp', 'conv'}
+            lstm_hidden_size=128,
+            latent_state_dim=128,
+            downsample=True,
+            image_channel=5,
+        ),
+        cuda=True,
+        env_type='not_board_games',
+        game_segment_length=50,
+        update_per_collect=update_per_collect,
+        batch_size=batch_size,
+        optim_type='Adam',
+        lr_piecewise_constant_decay=False,
+        learning_rate=0.003,
+        # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper.
+        policy_entropy_loss_weight=5e-3,
+        num_simulations=num_simulations,
+        reanalyze_ratio=reanalyze_ratio,
+        n_episode=n_episode,
+        eval_freq=int(2000),
+        replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in terms of transitions.
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+    ),
+)
+metadrive_sampled_efficientzero_config = EasyDict(metadrive_sampled_efficientzero_config)
+main_config = metadrive_sampled_efficientzero_config
+
+metadrive_sampled_efficientzero_create_config = dict(
+    env=dict(
+        type='metadrive_lightzero',
+        import_names=['zoo.metadrive.env.metadrive_env'],
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(
+        type='sampled_efficientzero',
+        import_names=['lzero.policy.sampled_efficientzero'],
+    ),
+    collector=dict(
+        type='episode_muzero',
+        get_train_sample=True,
+        import_names=['lzero.worker.muzero_collector'],
+    )
+)
+metadrive_sampled_efficientzero_create_config = EasyDict(metadrive_sampled_efficientzero_create_config)
+create_config = metadrive_sampled_efficientzero_create_config
+if __name__ == "__main__":
+
+    from lzero.entry import train_muzero
+    train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/zoo/metadrive/env/drive_env.py b/zoo/metadrive/env/drive_env.py
new file mode 100644
index 000000000..3ef57ea7e
--- /dev/null
+++ b/zoo/metadrive/env/drive_env.py
@@ -0,0 +1,394 @@
+import copy
+import gym
+import numpy as np
+from ditk import logging
+from typing import Union, Dict, AnyStr, Tuple, Optional
+from gym.envs.registration import register
+from metadrive.manager.traffic_manager import TrafficMode
+from metadrive.obs.top_down_obs_multi_channel import TopDownMultiChannel
+from metadrive.constants import RENDER_MODE_NONE, DEFAULT_AGENT, REPLAY_DONE, TerminationState
+from metadrive.envs.base_env import BaseEnv
+from metadrive.component.map.base_map import BaseMap
+from metadrive.component.map.pg_map import parse_map_config, MapGenerateMethod
+from metadrive.component.pgblock.first_block import FirstPGBlock
+from metadrive.component.vehicle.base_vehicle import BaseVehicle
+from metadrive.utils import Config, merge_dicts, get_np_random, clip
+from metadrive.envs.base_env import BASE_DEFAULT_CONFIG
+from metadrive.component.road_network import Road
+from metadrive.component.algorithm.blocks_prob_dist import PGBlockDistConfig
+
+METADRIVE_DEFAULT_CONFIG = dict(
+    # ===== Generalization =====
+    start_seed=0,
+    environment_num=10,
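+    # Each episode samples its scenario seed from [start_seed, start_seed + environment_num); see _reset_global_seed() below.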
decision_repeat=20, + block_dist_config=PGBlockDistConfig, + + # ===== Map Config ===== + map=3, # int or string: an easy way to fill map_config + random_lane_width=False, + random_lane_num=False, + map_config={ + BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_NUM, + BaseMap.GENERATE_CONFIG: None, # it can be a file path / block num / block ID sequence + BaseMap.LANE_WIDTH: 3.5, + BaseMap.LANE_NUM: 3, + "exit_length": 50, + }, + + # ===== Traffic ===== + traffic_density=0.1, + need_inverse_traffic=False, + traffic_mode=TrafficMode.Trigger, # "Respawn", "Trigger" + random_traffic=False, # Traffic is randomized at default. + traffic_vehicle_config=dict( + show_navi_mark=False, + show_dest_mark=False, + enable_reverse=False, + show_lidar=False, + show_lane_line_detector=False, + show_side_detector=False, + ), + + # ===== Object ===== + accident_prob=0., # accident may happen on each block with this probability, except multi-exits block + + # ===== Others ===== + use_AI_protector=False, + save_level=0.5, + is_multi_agent=False, + vehicle_config=dict(spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0)), + + # ===== Agent ===== + random_spawn_lane_index=True, + target_vehicle_configs={ + DEFAULT_AGENT: dict( + use_special_color=True, + spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0), + ) + }, + + # ===== Reward Scheme ===== + # See: https://github.com/decisionforce/metadrive/issues/283 + success_reward=10.0, + out_of_road_penalty=5.0, + crash_vehicle_penalty=5.0, + crash_object_penalty=5.0, + driving_reward=1.0, + speed_reward=0.1, + use_lateral_reward=False, + + # ===== Cost Scheme ===== + crash_vehicle_cost=1.0, + crash_object_cost=1.0, + out_of_road_cost=1.0, + + # ===== Termination Scheme ===== + out_of_route_done=False, + on_screen=False, + show_bird_view=False, +) + + +class MetaDrive(BaseEnv): + + @classmethod + def default_config(cls) -> "Config": + config = super(MetaDrive, cls).default_config() + config.update(METADRIVE_DEFAULT_CONFIG) + config.register_type("map", str, int) + config["map_config"].register_type("config", None) + return config + + def __init__(self, config: Union[dict, None] = None): + self.raw_cfg = config.metadrive + self.default_config_copy = Config(self.default_config(), unchangeable=True) + self.init_flag = False + + @property + def observation_space(self): + return gym.spaces.Box(0, 1, shape=(84, 84, 5), dtype=np.float32) + + @property + def action_space(self): + return gym.spaces.Box(-1, 1, shape=(2, ), dtype=np.float32) + + @property + def reward_space(self): + return gym.spaces.Box(-100, 100, shape=(1, ), dtype=np.float32) + + def seed(self, seed, dynamic_seed=False): + # TODO implement dynamic_seed mechanism + super().seed(seed) + + def reset(self): + if not self.init_flag: + super(MetaDrive, self).__init__(self.raw_cfg) + self.start_seed = self.config["start_seed"] + self.env_num = self.config["environment_num"] + self.init_flag = True + obs = super().reset() + return obs + + def _merge_extra_config(self, config: Union[dict, "Config"]) -> "Config": + config = self.default_config().update(config, allow_add_new_key=False) + if config["vehicle_config"]["lidar"]["distance"] > 50: + config["max_distance"] = config["vehicle_config"]["lidar"]["distance"] + return config + + def _post_process_config(self, config): + config = super(MetaDrive, self)._post_process_config(config) + if not config["rgb_clip"]: + logging.warning( + "You have set rgb_clip = False, which means the observation will be uint8 values in [0, 255]. 
" + "Please make sure you have parsed them later before feeding them to network!" + ) + config["map_config"] = parse_map_config( + easy_map_config=config["map"], new_map_config=config["map_config"], default_config=self.default_config_copy + ) + config["vehicle_config"]["rgb_clip"] = config["rgb_clip"] + config["vehicle_config"]["random_agent_model"] = config["random_agent_model"] + if config.get("gaussian_noise", 0) > 0: + assert config["vehicle_config"]["lidar"]["gaussian_noise"] == 0, "You already provide config!" + assert config["vehicle_config"]["side_detector"]["gaussian_noise"] == 0, "You already provide config!" + assert config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] == 0, "You already provide config!" + config["vehicle_config"]["lidar"]["gaussian_noise"] = config["gaussian_noise"] + config["vehicle_config"]["side_detector"]["gaussian_noise"] = config["gaussian_noise"] + config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] = config["gaussian_noise"] + if config.get("dropout_prob", 0) > 0: + assert config["vehicle_config"]["lidar"]["dropout_prob"] == 0, "You already provide config!" + assert config["vehicle_config"]["side_detector"]["dropout_prob"] == 0, "You already provide config!" + assert config["vehicle_config"]["lane_line_detector"]["dropout_prob"] == 0, "You already provide config!" + config["vehicle_config"]["lidar"]["dropout_prob"] = config["dropout_prob"] + config["vehicle_config"]["side_detector"]["dropout_prob"] = config["dropout_prob"] + config["vehicle_config"]["lane_line_detector"]["dropout_prob"] = config["dropout_prob"] + target_v_config = copy.deepcopy(config["vehicle_config"]) + if not config["is_multi_agent"]: + target_v_config.update(config["target_vehicle_configs"][DEFAULT_AGENT]) + config["target_vehicle_configs"][DEFAULT_AGENT] = target_v_config + return config + + def step(self, actions: Union[np.ndarray, Dict[AnyStr, np.ndarray]]): + actions = self._preprocess_actions(actions) + engine_info = self._step_simulator(actions) + o, r, d, i = self._get_step_return(actions, engine_info=engine_info) + return o, r, d, i + + def cost_function(self, vehicle_id: str): + vehicle = self.vehicles[vehicle_id] + step_info = dict() + step_info["cost"] = 0 + if self._is_out_of_road(vehicle): + step_info["cost"] = self.config["out_of_road_cost"] + elif vehicle.crash_vehicle: + step_info["cost"] = self.config["crash_vehicle_cost"] + elif vehicle.crash_object: + step_info["cost"] = self.config["crash_object_cost"] + return step_info['cost'], step_info + + def _is_out_of_road(self, vehicle): + ret = vehicle.on_yellow_continuous_line or vehicle.on_white_continuous_line or \ + (not vehicle.on_lane) or vehicle.crash_sidewalk + if self.config["out_of_route_done"]: + ret = ret or vehicle.out_of_route + return ret + + def done_function(self, vehicle_id: str): + vehicle = self.vehicles[vehicle_id] + done = False + done_info = { + TerminationState.CRASH_VEHICLE: False, + TerminationState.CRASH_OBJECT: False, + TerminationState.CRASH_BUILDING: False, + TerminationState.OUT_OF_ROAD: False, + TerminationState.SUCCESS: False, + TerminationState.MAX_STEP: False, + TerminationState.ENV_SEED: self.current_seed, + } + if self._is_arrive_destination(vehicle): + done = True + logging.info("Episode ended! Reason: arrive_dest.") + done_info[TerminationState.SUCCESS] = True + if self._is_out_of_road(vehicle): + done = True + logging.info("Episode ended! 
Reason: out_of_road.") + done_info[TerminationState.OUT_OF_ROAD] = True + if vehicle.crash_vehicle: + done = True + logging.info("Episode ended! Reason: crash vehicle ") + done_info[TerminationState.CRASH_VEHICLE] = True + if vehicle.crash_object: + done = True + done_info[TerminationState.CRASH_OBJECT] = True + logging.info("Episode ended! Reason: crash object ") + if vehicle.crash_building: + done = True + done_info[TerminationState.CRASH_BUILDING] = True + logging.info("Episode ended! Reason: crash building ") + if self.config["max_step_per_agent"] is not None and \ + self.episode_lengths[vehicle_id] >= self.config["max_step_per_agent"]: + done = True + done_info[TerminationState.MAX_STEP] = True + logging.info("Episode ended! Reason: max step ") + + if self.config["horizon"] is not None and \ + self.episode_lengths[vehicle_id] >= self.config["horizon"] and not self.is_multi_agent: + # single agent horizon has the same meaning as max_step_per_agent + done = True + done_info[TerminationState.MAX_STEP] = True + logging.info("Episode ended! Reason: max step ") + + done_info[TerminationState.CRASH] = ( + done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT] + or done_info[TerminationState.CRASH_BUILDING] + ) + # done_info['out_of_road'] = False + done_info['complete_ratio'] = clip(self.already_go_dist/ self.navi_distance + 0.05, 0.0, 1.0) + return done, done_info + + def reward_function(self, vehicle_id: str): + """ + Override this func to get a new reward function + :param vehicle_id: id of BaseVehicle + :return: reward + """ + vehicle = self.vehicles[vehicle_id] + step_info = dict() + if self._compute_navi_dist: + self.navi_distance = self.get_navigation_len(vehicle) + # if not self.config['const_episode_max_step']: + # self.episode_max_step = self.get_episode_max_step(self.navi_distance, self.avg_speed) + self._compute_navi_dist = False + + + # Reward for moving forward in current lane + if vehicle.lane in vehicle.navigation.current_ref_lanes: + current_lane = vehicle.lane + positive_road = 1 + else: + current_lane = vehicle.navigation.current_ref_lanes[0] + current_road = vehicle.navigation.current_road + positive_road = 1 if not current_road.is_negative_road() else -1 + long_last, _ = current_lane.local_coordinates(vehicle.last_position) + long_now, lateral_now = current_lane.local_coordinates(vehicle.position) + self.already_go_dist += (long_now - long_last) + + # reward for lane keeping, without it vehicle can learn to overtake but fail to keep in lane + if self.config["use_lateral_reward"]: + lateral_factor = clip(1 - 2 * abs(lateral_now) / vehicle.navigation.get_current_lane_width(), 0.0, 1.0) + else: + lateral_factor = 1.0 + + reward = 0.0 + reward += self.config["driving_reward"] * (long_now - long_last) * lateral_factor * positive_road + reward += self.config["speed_reward"] * (vehicle.speed / vehicle.max_speed) * positive_road + + step_info["step_reward"] = reward + + if self._is_arrive_destination(vehicle): + reward = +self.config["success_reward"] + elif self._is_out_of_road(vehicle): + reward = -self.config["out_of_road_penalty"] + elif vehicle.crash_vehicle: + reward = -self.config["crash_vehicle_penalty"] + elif vehicle.crash_object: + reward = -self.config["crash_object_penalty"] + return reward, step_info + + def _get_reset_return(self): + ret = {} + self.engine.after_step() + self._compute_navi_dist = True + self.already_go_dist = 0 + for v_id, v in self.vehicles.items(): + self.observations[v_id].reset(self, v) + ret[v_id] = 
self.observations[v_id].observe(v) + return ret if self.is_multi_agent else self._wrap_as_single_agent(ret) + + def switch_to_third_person_view(self) -> (str, BaseVehicle): + if self.main_camera is None: + return + self.main_camera.reset() + if self.config["prefer_track_agent"] is not None and self.config["prefer_track_agent"] in self.vehicles.keys(): + new_v = self.vehicles[self.config["prefer_track_agent"]] + current_track_vehicle = new_v + else: + if self.main_camera.is_bird_view_camera(): + current_track_vehicle = self.current_track_vehicle + else: + vehicles = list(self.engine.agents.values()) + if len(vehicles) <= 1: + return + if self.current_track_vehicle in vehicles: + vehicles.remove(self.current_track_vehicle) + new_v = get_np_random().choice(vehicles) + current_track_vehicle = new_v + self.main_camera.track(current_track_vehicle) + return + + def switch_to_top_down_view(self): + self.main_camera.stop_track() + + def setup_engine(self): + super(MetaDrive, self).setup_engine() + self.engine.accept("b", self.switch_to_top_down_view) + self.engine.accept("q", self.switch_to_third_person_view) + from metadrive.manager.traffic_manager import TrafficManager + from metadrive.manager.map_manager import MapManager + self.engine.register_manager("map_manager", MapManager()) + self.engine.register_manager("traffic_manager", TrafficManager()) + + def _is_arrive_destination(self, vehicle): + long, lat = vehicle.navigation.final_lane.local_coordinates(vehicle.position) + flag = (vehicle.navigation.final_lane.length - 5 < long < vehicle.navigation.final_lane.length + 5) and ( + vehicle.navigation.get_current_lane_width() / 2 >= lat >= + (0.5 - vehicle.navigation.get_current_lane_num()) * vehicle.navigation.get_current_lane_width() + ) + return flag + + def _reset_global_seed(self, force_seed=None): + """ + Current seed is set to force seed if force_seed is not None. + Otherwise, current seed is randomly generated. 
+        """
+        current_seed = force_seed if force_seed is not None else \
+            get_np_random(self._DEBUG_RANDOM_SEED).randint(self.start_seed, self.start_seed + self.env_num)
+        self.seed(current_seed)
+
+    def _get_observations(self):
+        return {DEFAULT_AGENT: self.get_single_observation(self.config["vehicle_config"])}
+
+    def get_single_observation(self, _=None):
+        return TopDownMultiChannel(
+            self.config["vehicle_config"],
+            self.config["on_screen"],
+            self.config["rgb_clip"],
+            frame_stack=3,
+            post_stack=10,
+            frame_skip=1,
+            resolution=(84, 84),
+            max_distance=36,
+        )
+
+    def clone(self, caller: str):
+        cfg = copy.deepcopy(self.raw_cfg)
+        return MetaDrive(cfg)
+
+    def get_navigation_len(self, vehicle):
+        checkpoints = vehicle.navigation.checkpoints
+        road_network = vehicle.navigation.map.road_network
+        total_dist = 0
+        assert len(checkpoints) >= 2
+        for check_num in range(0, len(checkpoints) - 1):
+            front_node = checkpoints[check_num]
+            end_node = checkpoints[check_num + 1]
+            cur_lanes = road_network.graph[front_node][end_node]
+            target_lane_num = int(len(cur_lanes) / 2)
+            target_lane = cur_lanes[target_lane_num]
+            target_lane_length = target_lane.length
+            total_dist += target_lane_length
+
+        if hasattr(vehicle.navigation, 'u_turn_case'):
+            if vehicle.navigation.u_turn_case is True:
+                total_dist += 35
+        return total_dist
\ No newline at end of file
diff --git a/zoo/metadrive/env/metadrive_env.py b/zoo/metadrive/env/metadrive_env.py
new file mode 100644
index 000000000..7438c36d2
--- /dev/null
+++ b/zoo/metadrive/env/metadrive_env.py
@@ -0,0 +1,222 @@
+from typing import Any, Dict, Optional
+from easydict import EasyDict
+import matplotlib.pyplot as plt
+import gymnasium as gym
+import copy
+import numpy as np
+from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep
+from ding.torch_utils.data_helper import to_ndarray
+from ding.utils.default_helper import deep_merge_dicts
+from ding.utils import ENV_REGISTRY
+
+from zoo.metadrive.env.drive_env import MetaDrive
+
+def draw_multi_channels_top_down_observation(obs, show_time=0.5):
+    """
+    Overview:
+        Displays a multi-channel top-down observation from an autonomous vehicle.
+    Arguments:
+        - obs (:obj:`numpy.ndarray`): A 3D NumPy array of shape (height, width, 5) representing the observation data,
+            where the last dimension corresponds to the five distinct channels.
+        - show_time (:obj:`float`): The duration in seconds for which the observation image will be displayed. Defaults to 0.5 seconds.
+    """
+    # Validate that there are exactly five channels in the observation data.
+    num_channels = obs.shape[-1]
+    assert num_channels == 5, "The observation data must have exactly 5 channels."
+
+    # Define the names for each of the five channels.
+    channel_names = [
+        "Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1",
+        "Neighbor at step t-2"
+    ]
+
+    # Create a figure with a subplot for each channel.
+    fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)
+
+    # Initialize a counter to track the current channel index.
+    count = 0
+
+    # Define a callback function to close the figure after the specified show_time.
+    def close_event():
+        plt.close(fig)  # Explicitly close the figure referenced by 'fig'.
+
+    # Create a timer that triggers the close_event after the specified duration.
+    timer = fig.canvas.new_timer(interval=show_time * 1000)
+    timer.add_callback(close_event)
+
+    # Iterate over each channel and display its observation data.
+    for i, name in enumerate(channel_names):
+        count += 1
+        ax = axs[i]  # Retrieve the subplot for the current channel.
+        ax.imshow(obs[..., i], cmap="bone")  # Display the observation data using a bone colormap.
+        ax.set_xticks([])  # Hide the x-axis ticks.
+        ax.set_yticks([])  # Hide the y-axis ticks.
+        ax.set_title(name)  # Set the title for the subplot based on the channel name.
+
+    # Set a title for the entire figure that summarizes the content.
+    fig.suptitle("Multi-channel Top-down Observation", fontsize='large')
+
+    # Start the timer to initiate the automatic closing of the figure.
+    timer.start()
+
+    # Display the figure with the multi-channel observation data.
+    plt.show()
+
+    # Close the figure after it has been displayed for the specified duration.
+    plt.close()  # Explicitly close the figure to ensure it is properly closed.
+
+@ENV_REGISTRY.register('metadrive_lightzero')
+class MetaDriveEnv(BaseEnv):
+    """
+    Overview:
+        MetaDrive environment in LightZero.
+    """
+    config = dict(
+        # (bool) Whether to use continuous action space
+        continuous=True,
+        # replay_path (str or None): The path to save the replay video. If None, the replay will not be saved.
+        # Only effective when env_manager.type is 'base'.
+        replay_path=None,
+        # (bool) Whether to scale action into [-2, 2]
+        act_scale=True,
+    )
+
+    @classmethod
+    def default_config(cls: type) -> EasyDict:
+        cfg = EasyDict(copy.deepcopy(cls.config))
+        cfg.cfg_type = cls.__name__ + 'Dict'
+        return cfg
+
+    def __init__(self, cfg: dict = {}) -> None:
+        """
+        Overview:
+            Initialize the environment with a configuration dictionary. Set up spaces for observations, actions, and rewards.
+        Arguments:
+            - cfg (:obj:`dict`): Configuration dict.
+        """
+        # Initialize a raw env
+        self._cfg = cfg
+        self._env = MetaDrive(self._cfg)
+        self._init_flag = True
+        self._reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
+        self._action_space = self._env.action_space
+        self._observation_space = self._env.observation_space
+
+        # bird view
+        self.show_bird_view = False
+
+    def reset(self, *args, **kwargs) -> Any:
+        """
+        Overview:
+            Reset the environment and return the initial observation.
+        Returns:
+            - metadrive_obs (:obj:`dict`): An observation dict for the MetaDrive env which includes ``observation``, ``action_mask``, ``to_play``.
+        """
+        obs = self._env.reset(*args, **kwargs)
+        obs = to_ndarray(obs, dtype=np.float32)
+        if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
+            # obs = obs.transpose((2, 0, 1))
+            obs = obs
+        elif isinstance(obs, dict):
+            vehicle_state = obs['vehicle_state']
+            # birdview = obs['birdview'].transpose((2, 0, 1))
+            birdview = obs['birdview']
+            obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
+        self._eval_episode_return = 0.0
+        self._arrive_dest = False
+        self._observation_space = self._env.observation_space
+
+        metadrive_obs = {}
+        metadrive_obs['observation'] = obs
+        metadrive_obs['action_mask'] = None
+        metadrive_obs['to_play'] = -1
+        return metadrive_obs
+
+    def step(self, action: np.ndarray = None) -> BaseEnvTimestep:
+        """
+        Overview:
+            Wrapper of ``step`` method in env. This aims to convert the returns of ``gym.Env`` step method into
+            that of ``ding.envs.BaseEnv``, from ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep``
+            namedtuple defined in DI-engine. It will also convert actions, observations and rewards into
+            ``np.ndarray``. In the original MetaDrive setting, the action can be None, but in our pipeline an
+            action is always applied to the environment.
+ Arguments: + - action (:obj:`np.ndarray`): The action to be performed in the environment. + Returns: + - timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag, + and info dictionary. + """ + action = to_ndarray(action) + obs, rew, done, info = self._env.step(action) + if self.show_bird_view: + draw_multi_channels_top_down_observation(obs, show_time=0.5) + self._eval_episode_return += rew + obs = to_ndarray(obs, dtype=np.float32) + if isinstance(obs, np.ndarray) and len(obs.shape) == 3: + # obs = obs.transpose((2, 0, 1)) + obs = obs + elif isinstance(obs, dict): + vehicle_state = obs['vehicle_state'] + # birdview = obs['birdview'].transpose((2, 0, 1)) + birdview = obs['birdview'] + obs = {'vehicle_state': vehicle_state, 'birdview': birdview} + rew = to_ndarray([rew], dtype=np.float32) + if done: + info['eval_episode_return'] = self._eval_episode_return + metadrive_obs = {} + metadrive_obs['observation'] = obs + metadrive_obs['action_mask'] = None + metadrive_obs['to_play'] = -1 + return BaseEnvTimestep(metadrive_obs, rew, done, info) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment's random number generator. Can handle both static and dynamic seeding. + """ + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + self._env = gym.wrappers.Monitor(self._env, self._replay_path, video_callable=lambda episode_id: True, force=True) + + def render(self): + self._env.render() + + @property + def observation_space(self) -> gym.spaces.Space: + """ + Property to access the observation space of the environment. + """ + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + """ + Property to access the action space of the environment. + """ + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + """ + Property to access the reward space of the environment. + """ + return self._reward_space + + def close(self) -> None: + """ + Close the environment, and set the initialization flag to False. + """ + if self._init_flag: + self._env.close() + self._init_flag = False + + def __repr__(self) -> str: + return repr(self._env) + + def clone(self): + cfg = copy.deepcopy(self._cfg) + return MetaDriveEnv(cfg) \ No newline at end of file