Add bubble_env contribution #1537
Changes from all commits: 5f88b32, 609ccd1, 9a5b15d, 9e748c1, f0f8787, 5658ee3, b30dda4, 0dcecb0, 94ff36d, 32d9758, 6845ee4, cf1d8d5, 93382ab, 0df120f, a5791a3
@@ -2,3 +2,4 @@ phase: track1
 eval_episodes: 50
 seed: 42
 scenarios: []
+bubble_env_evaluation_seeds: []
@@ -1,4 +1,5 @@
 phase: track2
 eval_episodes: 50
 seed: 42
-scenarios: []
+scenarios: []
+bubble_env_evaluation_seeds: []
@@ -3,3 +3,4 @@ eval_episodes: 1
 seed: 42
 scenarios:
 - 1_to_2lane_left_turn_c
+bubble_env_evaluation_seeds: []
@@ -1,5 +1,5 @@
 import copy
-from typing import Any, Dict, Iterable, Tuple
+from typing import Any, Dict, Iterable, List, Tuple

 import gym

@@ -10,7 +10,10 @@ def __init__(self):
         self._agent_names = None

     def __call__(self, **kwargs):
-        self._data = copy.deepcopy(dict(**kwargs))
+        try:
+            self._data = copy.deepcopy(dict(**kwargs))
+        except RecursionError:
+            self._data = copy.copy(dict(**kwargs))
Comment on lines +13 to +16:

  How does …

  I am unsure how the …

  @Adaickalavan To be clear, I think this is fine to update at a slightly later date because it is only useful to us.
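For context on the guarded copy above, here is a minimal, hypothetical sketch (not from the PR) of the failure mode the try/except appears to protect against: `copy.deepcopy` recurses once per nesting level, so a sufficiently deep observation structure can exhaust Python's recursion limit, while `copy.copy` only duplicates the top-level container.

```python
import copy

def nested(depth):
    # Build a dict nested `depth` levels deep to mimic a very deep observation tree.
    d = {}
    for _ in range(depth):
        d = {"child": d}
    return d

obs = {"Agent_0": nested(10_000)}
try:
    snapshot = copy.deepcopy(obs)   # recurses per nesting level; can raise RecursionError
except RecursionError:
    snapshot = copy.copy(obs)       # shallow fallback: only the top-level dict is copied
```

The trade-off of the fallback is that the snapshot shares nested objects with the live data, which is presumably acceptable here given the reviewer's note that the datastore is only used internally.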

     @property
     def data(self):

@@ -26,10 +29,10 @@ def agent_names(self, names: Iterable[str]):


 class CopyData(gym.Wrapper):
-    def __init__(self, env: gym.Env, datastore: DataStore):
+    def __init__(self, env: gym.Env, agent_ids: List[str], datastore: DataStore):
         super(CopyData, self).__init__(env)
         self._datastore = datastore
-        self._datastore.agent_names = list(env.agent_specs.keys())
+        self._datastore.agent_names = agent_ids

     def step(
         self, action: Dict[str, Any]
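A brief usage sketch of the updated constructor (the `copy_data` import path and the environment arguments are assumptions for illustration; the real call sites are in the evaluation script below). The wrapper no longer reads `env.agent_specs` itself, so the caller decides where the agent ids come from, which also suits environments such as bubble_env that expose `agent_ids` rather than `agent_specs`.

```python
import gym

from copy_data import CopyData, DataStore  # assumed module name for the file above

# Hypothetical call site: agent ids are now supplied by the caller.
env = gym.make(
    "smarts.env:multi-scenario-v0",
    scenario="1_to_2lane_left_turn_c",
    img_meters=64,   # illustrative values
    img_pixels=256,
    action_space="TargetPose",
    sumo_headless=True,
)
datastore = DataStore()
env = CopyData(env, agent_ids=list(env.agent_specs.keys()), datastore=datastore)
```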
@@ -4,7 +4,7 @@
 import subprocess
 import sys
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional

 logger = logging.getLogger(__file__)

@@ -15,6 +15,7 @@
     "eval_episodes",
     "seed",
     "scenarios",
+    "bubble_env_evaluation_seeds",

Review comment: Consider using short variable names. Here, we could use …

 }
 _DEFAULT_EVALUATION_CONFIG = dict(
     phase="track1",
@@ -30,6 +31,7 @@
         "3lane_cut_in",
         "3lane_overtake",
     ],
+    bubble_env_evaluation_seeds=[6],
 )
 _SUBMISSION_CONFIG_KEYS = {
     "img_meters",
@@ -41,36 +43,24 @@
 )


-def make_env(
-    config: Dict[str, Any],
-    scenario: str,
+def wrap_env(
+    env,
+    agent_ids: List[str],
     datastore: "DataStore",
     wrappers=[],
 ):
     """Make environment.

     Args:
-        config (Dict[str, Any]): A dictionary of config parameters.
-        scenario (str): Scenario
+        env (gym.Env): The environment to wrap.
         wrappers (List[gym.Wrapper], optional): Sequence of gym environment wrappers.
             Defaults to empty list [].

     Returns:
-        gym.Env: Environment corresponding to the `scenario`.
+        gym.Env: Environment wrapped for evaluation.
     """

-    # Create environment
-    env = gym.make(
-        "smarts.env:multi-scenario-v0",
-        scenario=scenario,
-        img_meters=int(config["img_meters"]),
-        img_pixels=int(config["img_pixels"]),
-        action_space="TargetPose",
-        sumo_headless=True,
-    )
-
     # Make a copy of original info.
-    env = CopyData(env, datastore)
+    env = CopyData(env, agent_ids, datastore)
     # Disallow modification of attributes starting with "_" by external users.
     env = gym.Wrapper(env)

@@ -82,35 +72,60 @@ def make_env(


 def evaluate(config):
-    scenarios = config["scenarios"]
-
+    base_scenarios = config["scenarios"]
+    shared_configs = dict(
+        action_space="TargetPose",
+        img_meters=int(config["img_meters"]),
+        img_pixels=int(config["img_pixels"]),
+        sumo_headless=True,
+    )
     # Make evaluation environments.
     envs_eval = {}
-    for scen in scenarios:
+    for scenario in base_scenarios:
+        env = gym.make(
+            "smarts.env:multi-scenario-v0", scenario=scenario, **shared_configs
+        )
+        datastore = DataStore()
+        envs_eval[f"{scenario}"] = (
+            wrap_env(
+                env,
+                agent_ids=list(env.agent_specs.keys()),
+                datastore=datastore,
+                wrappers=submitted_wrappers(),
+            ),
+            datastore,
+            None,
+        )
+
+    bonus_eval_seeds = config.get("bubble_env_evaluation_seeds", [])
+    for seed in bonus_eval_seeds:
+        env = gym.make("bubble_env_contrib:bubble_env-v0", **shared_configs)
         datastore = DataStore()
-        envs_eval[f"{scen}"] = (
-            make_env(
-                config=config,
-                scenario=scen,
+        envs_eval[f"bubble_env_{seed}"] = (
+            wrap_env(
+                env,
+                agent_ids=list(env.agent_ids),
                 datastore=datastore,
                 wrappers=submitted_wrappers(),
             ),
             datastore,
+            seed,
         )

     # Instantiate submitted policy.
     policy = Policy()

     # Evaluate model for each scenario
     score = Score()
-    for index, (env_name, (env, datastore)) in enumerate(envs_eval.items()):
+    for index, (env_name, (env, datastore, seed)) in enumerate(envs_eval.items()):
         logger.info(f"\n{index}. Evaluating env {env_name}.\n")
         counts, costs = run(
             env=env,
             datastore=datastore,
             env_name=env_name,
             policy=policy,
             config=config,
+            seed=seed,
         )
         score.add(counts, costs)

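For orientation, a rough sketch (not part of the diff; names taken from the default config above) of the registry this loop builds: each entry is a `(wrapped_env, datastore, seed)` triple, with `None` for the ranked scenario environments and the evaluation seed for the bonus bubble_env entries. Using `config.get("bubble_env_evaluation_seeds", [])` also keeps older configs that lack the new key working, since the bonus loop then just iterates an empty list.

```python
from typing import Any, Dict, Optional, Tuple

# env_name -> (wrapped_env, datastore, seed)
EnvEntry = Tuple[Any, Any, Optional[int]]

envs_eval: Dict[str, EnvEntry] = {
    "1_to_2lane_left_turn_c": ("<wrapped scenario env>", "<datastore>", None),
    "3lane_overtake": ("<wrapped scenario env>", "<datastore>", None),
    "bubble_env_6": ("<wrapped bubble_env>", "<datastore>", 6),
}
```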
@@ -119,18 +134,25 @@ def evaluate(config):
     logger.info("\nFinished evaluating.\n")

     # Close all environments
-    for env, _ in envs_eval.values():
+    for env, _, _ in envs_eval.values():
         env.close()

     return rank


 def run(
-    env, datastore: "DataStore", env_name: str, policy: "Policy", config: Dict[str, Any]
+    env,
+    datastore: "DataStore",
+    env_name: str,
+    policy: "Policy",
+    config: Dict[str, Any],
+    seed: Optional[int],
 ):
     # Instantiate metric for score calculation.
     metric = Metric(env_name=env_name, agent_names=datastore.agent_names)

+    # Ensure deterministic seeding
+    env.seed((seed or 0) + config["seed"])
     for _ in range(config["eval_episodes"]):
         observations = env.reset()
         dones = {"__all__": False}
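As a quick worked illustration of the seeding line (values assumed from the default config above, `seed=42` and `bubble_env_evaluation_seeds=[6]`):

```python
config_seed = 42

# Scenario environments are registered with seed=None, so they keep the base seed.
scenario_env_seed = (None or 0) + config_seed    # -> 42

# The bonus bubble_env registered under evaluation seed 6 is offset from the base seed.
bubble_env_seed = (6 or 0) + config_seed         # -> 48
```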
@@ -203,6 +225,7 @@ def to_codalab_scores_string(rank) -> str:
             "pip",
             "install",
             "smarts[camera-obs] @ git+https://github.com/huawei-noah/SMARTS.git@comp-1",
+            "bubble_env @ git+https://bitbucket.org/malban/bubble_env.git@master",
         ]
     )
     subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_file])
@@ -28,6 +28,6 @@ def step(
         obs, reward, done, info = self.env.step(action)

         for agent_id in info.keys():
-            info[agent_id]["is_success"] = bool(info[agent_id]["score"])
+            info[agent_id]["is_success"] = bool(info[agent_id].get("score", True))

Review comment: The … Why is the …

         return obs, reward, done, info
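A small illustration of the new fallback (the assumption, not stated in the diff, is that bubble_env's per-agent info may lack a `score` entry):

```python
info_with_score = {"score": 0.0}
info_without_score = {}  # e.g. an environment that does not report "score"

bool(info_with_score.get("score", True))     # False -> unchanged behaviour when "score" exists
bool(info_without_score.get("score", True))  # True  -> a missing key now counts as success
```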
Review comment:

  If bubble_env_evaluation_seeds is an empty list, then the additional bubble_env will not be run during Codalab evaluation (i.e., not included in competition scoring and ranking). It only runs during local evaluation by participants, as the default config sets bubble_env_evaluation_seeds=[6]. Was this the intended outcome?
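If the bonus environment is meant to count toward Codalab scoring, the evaluation-side config would presumably need a non-empty seed list as well. A hypothetical override, mirroring the shape of `_DEFAULT_EVALUATION_CONFIG` above (values are illustrative only):

```python
codalab_evaluation_config = dict(
    phase="track1",
    eval_episodes=50,
    seed=42,
    scenarios=[],                         # ranked scenario list stays as configured
    bubble_env_evaluation_seeds=[6],      # non-empty => bubble_env is evaluated and scored
)
```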