Skip to content

Commit 3ab362f

Browse files
committed
Initial commit
0 parents  commit 3ab362f

File tree

180 files changed

+19386
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

180 files changed

+19386
-0
lines changed

.gitignore

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# Mac
7+
.DS_Store
8+
9+
# data
10+
data/
11+
12+
# test notebook
13+
*.ipynb
14+
15+
# C extensions
16+
*.so
17+
18+
# Distribution / packaging
19+
.Python
20+
build/
21+
develop-eggs/
22+
dist/
23+
downloads/
24+
eggs/
25+
.eggs/
26+
lib/
27+
lib64/
28+
parts/
29+
sdist/
30+
var/
31+
wheels/
32+
pip-wheel-metadata/
33+
share/python-wheels/
34+
*.egg-info/
35+
.installed.cfg
36+
*.egg
37+
MANIFEST
38+
39+
# PyInstaller
40+
# Usually these files are written by a python script from a template
41+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
42+
*.manifest
43+
*.spec
44+
45+
# Installer logs
46+
pip-log.txt
47+
pip-delete-this-directory.txt
48+
49+
# Unit test / coverage reports
50+
htmlcov/
51+
.tox/
52+
.nox/
53+
.coverage
54+
.coverage.*
55+
.cache
56+
nosetests.xml
57+
coverage.xml
58+
*.cover
59+
*.py,cover
60+
.hypothesis/
61+
.pytest_cache/
62+
cover/
63+
64+
# Translations
65+
*.mo
66+
*.pot
67+
68+
# Django stuff:
69+
*.log
70+
local_settings.py
71+
db.sqlite3
72+
db.sqlite3-journal
73+
74+
# Flask stuff:
75+
instance/
76+
.webassets-cache
77+
78+
# Scrapy stuff:
79+
.scrapy
80+
81+
# Sphinx documentation
82+
docs/_build/
83+
84+
# PyBuilder
85+
.pybuilder/
86+
target/
87+
88+
# Jupyter Notebook
89+
.ipynb_checkpoints
90+
91+
# IPython
92+
profile_default/
93+
ipython_config.py
94+
95+
# pyenv
96+
# For a library or package, you might want to ignore these files since the code is
97+
# intended to run in multiple environments; otherwise, check them in:
98+
# .python-version
99+
100+
# pipenv
101+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
103+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
104+
# install all needed dependencies.
105+
#Pipfile.lock
106+
107+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
108+
__pypackages__/
109+
110+
# Celery stuff
111+
celerybeat-schedule
112+
celerybeat.pid
113+
114+
# SageMath parsed files
115+
*.sage.py
116+
117+
# Environments
118+
.env
119+
.venv
120+
venv/
121+
ENV/
122+
env.bak/
123+
venv.bak/
124+
125+
# Spyder project settings
126+
.spyderproject
127+
.spyproject
128+
129+
# Rope project settings
130+
.ropeproject
131+
132+
# mkdocs documentation
133+
/site
134+
135+
# mypy
136+
.mypy_cache/
137+
.dmypy.json
138+
dmypy.json
139+
140+
# Pyre type checker
141+
.pyre/
142+
143+
# pytype static type analyzer
144+
.pytype/
145+
146+
# Cython debug symbols
147+
cython_debug/
148+
149+
# static files generated from Django application using `collectstatic`
150+
media
151+
static

README.md

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# SLAMuZero
2+
3+
## acme-more-mcts
4+
5+
## Installation
6+
fork [acme-more-mcts](https://github.com/bwfbowen/acme-more-mcts) and run following command from the main directory(where `setup.py` is located):
7+
```sh
8+
pip install .[jax,tf,testing,envs]
9+
```
10+
11+
12+
Habitat is also needed:
13+
```sh
14+
pip install habitat-sim
15+
pip install habitat-api
16+
```
17+
18+
## Setup
19+
The project requires datasets in a `data` folder in the following format (same as habitat-api):
20+
```
21+
SLAMuZero/
22+
data/
23+
scene_datasets/
24+
gibson/
25+
Adrian.glb
26+
Adrian.navmesh
27+
...
28+
datasets/
29+
pointnav/
30+
gibson/
31+
v1/
32+
train/
33+
val/
34+
...
35+
```
36+
Please download the data using the instructions here: https://github.com/facebookresearch/habitat-api#data
37+
38+
## Getting started
39+
To run the code:
40+
```python
41+
python run_acme.py
42+
```
43+
44+
## Visualization
45+
Pass `--print_images 1` to plot trajectory
46+
```sh
47+
python run_acme.py --print_images 1
48+
```
49+
50+
And use `draw.py` to generate `.gif`
51+
```sh
52+
python draw.py --pic_dir PATH_TO_DUMP
53+
```
54+
55+
<img src="./assets/demo_s4.gif" alt="drawing" width="300"/>

assets/demo_s4.gif

2.22 MB
Loading

draw.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from PIL import Image
2+
import os
3+
import re
4+
5+
from absl import app
6+
from absl import flags
7+
8+
flags.DEFINE_string('pic_dir', "~/tmp/dump/nts_eval_img/episodes/1/1", 'task config.')
9+
flags.DEFINE_string('output_file', '~/tmp/dump/nav.gif', 'output file')
10+
flags.DEFINE_integer('frame_duration', 50, 'frame duration')
11+
flags.DEFINE_string('pic_suffix', '.png', 'pic suffix')
12+
flags.DEFINE_list('max_size', [100,100], 'max size')
13+
flags.DEFINE_boolean('optimize', True, 'whether to optimize')
14+
15+
FLAGS = flags.FLAGS
16+
17+
18+
def sorting_key(filename):
19+
match = re.search(r"(\d+)-(\d+)-Vis-(\d+)\.png$", filename)
20+
if match:
21+
return int(match.group(3))
22+
return 0 # Default if the pattern is not found
23+
24+
25+
def resize_image(image, max_size):
26+
return image.resize(max_size, Image.LANCZOS)
27+
28+
29+
def generate_gif_from_pics(
30+
pic_dir: str,
31+
output_file: str,
32+
frame_duration: int = 500,
33+
pic_suffix: str = '.png',
34+
max_size: tuple = (100, 100),
35+
sort_fn: callable = sorting_key,
36+
optimize: bool = True,
37+
):
38+
# Sort the file names
39+
pic_dir = os.path.expanduser(pic_dir)
40+
file_names = sorted([os.path.join(pic_dir, f) for f in os.listdir(pic_dir) if f.endswith(pic_suffix)], key=sort_fn)
41+
42+
# Create a list to hold the images
43+
images = []
44+
45+
# Load each file, resize if necessary, and append to images list
46+
for file_name in file_names:
47+
with Image.open(file_name) as img:
48+
img_resized = resize_image(img, tuple(map(int, max_size)))
49+
images.append(img_resized.copy()) # Copy to ensure the file is not left open
50+
51+
# Create and save the GIF
52+
gif_path = os.path.expanduser(output_file)
53+
images[0].save(gif_path, save_all=True, append_images=images[1:], optimize=optimize, duration=frame_duration, loop=0)
54+
55+
56+
def main(_):
57+
generate_gif_from_pics(FLAGS.pic_dir, FLAGS.output_file, FLAGS.frame_duration, FLAGS.pic_suffix, FLAGS.max_size, optimize=FLAGS.optimize)
58+
59+
60+
if __name__ == '__main__':
61+
app.run(main)

env/__init__.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
3+
from .habitat import construct_envs, Exploration_Env
4+
5+
6+
def make_vec_envs(args):
7+
envs = construct_envs(args)
8+
envs = VecPyTorch(envs, args.device)
9+
return envs
10+
11+
12+
# Adapted from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/envs.py#L159
13+
class VecPyTorch():
14+
15+
def __init__(self, venv, device):
16+
self.venv = venv
17+
self.num_envs = venv.num_envs
18+
self.observation_space = venv.observation_space
19+
self.action_space = venv.action_space
20+
self.device = device
21+
22+
def reset(self):
23+
obs, info = self.venv.reset()
24+
obs = torch.from_numpy(obs).float().to(self.device)
25+
return obs, info
26+
27+
def step_async(self, actions):
28+
actions = actions.cpu().numpy()
29+
self.venv.step_async(actions)
30+
31+
def step_wait(self):
32+
obs, reward, done, info = self.venv.step_wait()
33+
obs = torch.from_numpy(obs).float().to(self.device)
34+
reward = torch.from_numpy(reward).float()
35+
return obs, reward, done, info
36+
37+
def step(self, actions):
38+
actions = actions.cpu().numpy()
39+
obs, reward, done, info = self.venv.step(actions)
40+
obs = torch.from_numpy(obs).float().to(self.device)
41+
reward = torch.from_numpy(reward).float()
42+
return obs, reward, done, info
43+
44+
def get_rewards(self, inputs):
45+
reward = self.venv.get_rewards(inputs)
46+
reward = torch.from_numpy(reward).float()
47+
return reward
48+
49+
def get_short_term_goal(self, inputs):
50+
stg = self.venv.get_short_term_goal(inputs)
51+
stg = torch.from_numpy(stg).float()
52+
return stg
53+
54+
def close(self):
55+
return self.venv.close()

0 commit comments

Comments
 (0)