Add use-multiprocessing and verbose flags to Unity dataset processing. #1789

Merged · 1 commit · Feb 10, 2024
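With this change the processing script grows two opt-in flags. A hypothetical invocation follows; only `--verbose` and `--use-multiprocessing` are confirmed by this diff, while the scene and root-dir argument spellings are assumed from `args.hssd_hab_root_dir` and the `"one or more scene ids"` help text, since their `add_argument` calls fall outside the hunks shown below:

```bash
python scripts/unity_dataset_processing/unity_dataset_processing.py \
    --hssd-hab-root-dir data/hssd-hab/ \
    --scenes <scene_id> \
    --use-multiprocessing \
    --verbose
```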
scripts/unity_dataset_processing/decimate.py (5 additions, 6 deletions)
```diff
@@ -116,16 +116,15 @@ def decimate(
     for i in range(importer.mesh_count):
         mesh = importer.mesh(i)

-        # Transform the mesh to its max scale in the scene. For quantized meshes
-        # this expands the position attribute to a floating-point Vector3.
-        scaled_mesh = meshtools.transform3d(
-            mesh, Matrix4.scaling(max_mesh_scaling.get(i, Vector3(1.0)))
-        )
-
         # Calculate total triangle area of the *transformed* mesh. You might want
         # to fiddle with this heuristic; another option is calculating the mesh
         # AABB, but that won't do the right thing for planar meshes.
         if simplify:
+            # Transform the mesh to its max scale in the scene. For quantized meshes
+            # this expands the position attribute to a floating-point Vector3.
+            scaled_mesh = meshtools.transform3d(
+                mesh, Matrix4.scaling(max_mesh_scaling.get(i, Vector3(1.0)))
+            )
             if not scaled_mesh.is_indexed:
                 converter.end_file()
                 importer.close()
```
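The decimate.py change above defers the scaling transform until it is known to be needed: `meshtools.transform3d` de-quantizes the position attribute to floating-point `Vector3`, which is wasted work when `simplify` is off. A minimal sketch of the gating pattern, with hypothetical helpers standing in for the Magnum calls:

```python
def process(mesh, simplify: bool):
    # Pay the de-quantization cost only on the simplification path.
    if simplify:
        scaled_mesh = expensive_transform(mesh)  # hypothetical helper
        return simplify_mesh(scaled_mesh)        # hypothetical helper
    return mesh  # passthrough: the transform is never computed
```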
scripts/unity_dataset_processing/unity_dataset_processing.py (54 additions, 12 deletions)
```diff
@@ -7,6 +7,7 @@
 import argparse
 import json
 import os
+import time
 from multiprocessing import Manager, Pool
 from pathlib import Path
 from typing import Callable, List, Set
```
```diff
@@ -16,6 +17,7 @@
 OUTPUT_DIR = "data/hitl_simplified/data/"
 OMIT_BLACK_LIST = False
 OMIT_GRAY_LIST = False
+PROCESS_COUNT = os.cpu_count()


 class Job:
```
```diff
@@ -24,6 +26,13 @@ class Job:
     simplify: bool


+class Config:
+    # Increase logging verbosity.
+    verbose: bool = False
+    # Activate multiprocessing. Disable when debugging.
+    use_multiprocessing: bool = False
+
+
 def file_is_scene_config(filepath: str) -> bool:
     """
     Return whether or not the file is a scene_instance.json
```
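One subtlety of the `Config` pattern introduced above: `verbose` and `use_multiprocessing` are class attributes acting as defaults, and `main()` later shadows them with instance attributes. A quick sketch of those semantics:

```python
config = Config()
config.verbose = True  # creates an instance attribute

print(Config.verbose)  # False -- the class-level default is untouched
print(config.verbose)  # True  -- the instance attribute shadows it
```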
```diff
@@ -93,8 +102,14 @@ def get_model_ids_from_scene_instance_json(filepath: str) -> List[str]:
     return model_ids


+def validate_jobs(jobs: List[Job]):
+    for job in jobs:
+        assert Path(job.source_path).exists()
+        assert job.dest_path is not None and job.dest_path != ""
+
+
 def process_model(args):
-    job, counter, lock, total_models = args
+    job, counter, lock, total_models, verbose = args

     if os.path.isfile(job.dest_path):
         print(f"Skipping: {job.source_path}")
```
```diff
@@ -110,19 +125,21 @@ def process_model(args):
         source_tris, target_tris, simplified_tris = decimate.decimate(
             inputFile=job.source_path,
             outputFile=job.dest_path,
-            quiet=True,
+            quiet=not verbose,
+            verbose=verbose,
             sloppy=False,
             simplify=job.simplify,
         )
     except Exception:
         try:
             print(
-                f"Unable to decimate: {job.source_path}. Trying passthrough (no decimation)."
+                f"Unable to decimate: {job.source_path}. Trying without decimation."
             )
             source_tris, target_tris, simplified_tris = decimate.decimate(
                 inputFile=job.source_path,
                 outputFile=job.dest_path,
-                quiet=True,
+                quiet=not verbose,
+                verbose=verbose,
                 simplify=False,
             )
         except Exception:
```
```diff
@@ -135,8 +152,8 @@ def process_model(args):
             )

     result = {
-        "source_tris": source_tris,
-        "simplified_tris": simplified_tris,
+        "source_tris": str(source_tris),
+        "simplified_tris": str(simplified_tris),
         "source_path": job.source_path,
         "status": "ok",
     }
```
```diff
@@ -161,7 +178,8 @@ def process_model(args):
     return result


-def simplify_models(jobs: List[Job]):
+def simplify_models(jobs: List[Job], config: Config):
+    start_time = time.time()
     total_source_tris = 0
     total_simplified_tris = 0
     black_list = []
```
```diff
@@ -171,6 +189,9 @@ def simplify_models(jobs: List[Job], config: Config):
     total_skipped = 0
     total_error = 0

+    validate_jobs(jobs)
+    total_models = len(jobs)
+
     # Initialize counter and lock
     manager = Manager()
     counter = manager.Value("i", 0)
```
```diff
@@ -179,13 +200,14 @@ def simplify_models(jobs: List[Job], config: Config):
     total_models = len(jobs)

     # Pair up the model paths with the counter and lock
-    args_lists = [(job, counter, lock, total_models) for job in jobs]
+    args_lists = [
+        (job, counter, lock, total_models, config.verbose) for job in jobs
+    ]

     results = []

-    use_multiprocessing = False  # total_models > 6
-    if use_multiprocessing:
-        max_processes = 6
+    if config.use_multiprocessing:
+        max_processes = PROCESS_COUNT
         with Pool(processes=min(max_processes, total_models)) as pool:
             results = list(pool.map(process_model, args_lists))
     else:
```
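The progress counter shared across worker processes follows the standard `multiprocessing.Manager` recipe: a proxy `Value` guarded by a proxy `Lock`, both picklable and therefore safe to pass through `Pool.map`. A self-contained sketch of the same pattern (illustrative names, not the script's own):

```python
from multiprocessing import Manager, Pool


def work(args):
    item, counter, lock, total = args
    # ... process `item` ...
    with lock:  # guard the read-modify-write on the shared counter
        counter.value += 1
        print(f"{counter.value}/{total} done")
    return item * item


if __name__ == "__main__":
    items = list(range(8))
    manager = Manager()
    counter = manager.Value("i", 0)  # shared integer, starts at 0
    lock = manager.Lock()
    args_list = [(i, counter, lock, len(items)) for i in items]
    with Pool(processes=4) as pool:
        results = pool.map(work, args_list)
    print(results)
```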
```diff
@@ -226,6 +248,11 @@ def simplify_models(jobs: List[Job], config: Config):
             print(" " + item + ",")
         print("]")

+    elapsed_time = time.time() - start_time
+    print(f"Elapsed time (s): {elapsed_time}")
+    if not config.use_multiprocessing:
+        print("Add --use-multiprocessing to speed up processing.")
+

 def find_model_paths_in_scenes(hssd_hab_root_dir, scene_ids) -> List[str]:
     model_filepaths: Set[str] = set()
```
```diff
@@ -279,8 +306,23 @@ def main():
         type=str,
         help="one or more scene ids",
     )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        default=False,
+        help="Increase logging verbosity.",
+    )
+    parser.add_argument(
+        "--use-multiprocessing",
+        action="store_true",
+        default=False,
+        help="Enable multiprocessing.",
+    )

     args = parser.parse_args()
+    config = Config()
+    config.verbose = args.verbose
+    config.use_multiprocessing = args.use_multiprocessing

     # Force input paths to have a trailing slash
     if args.hssd_hab_root_dir[-1] != "/":
```
```diff
@@ -352,7 +394,7 @@ def main():
             job.simplify = False
             jobs.append(job)

-    simplify_models(jobs)
+    simplify_models(jobs, config)


 if __name__ == "__main__":
```
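For completeness, a hypothetical sketch of driving the pipeline programmatically rather than through the CLI, using only the `Job`, `Config`, and `simplify_models` signatures visible in this diff (the paths are placeholders):

```python
jobs = []
job = Job()
job.source_path = "path/to/source/model.glb"  # placeholder
job.dest_path = "path/to/output/model.glb"    # placeholder
job.simplify = True
jobs.append(job)

config = Config()
config.verbose = True
config.use_multiprocessing = True  # workers capped at PROCESS_COUNT

simplify_models(jobs, config)
```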