Merge pull request #139 from tryolabs/motion-estimator
Estimate camera motion using the mode of the Optical Flow
javiber authored Sep 1, 2022
2 parents 7dcfe9c + ab32eb8 commit 97cb9d8
Showing 13 changed files with 846 additions and 15 deletions.
6 changes: 6 additions & 0 deletions demos/camera_motion/Dockerfile
@@ -0,0 +1,6 @@
FROM ultralytics/yolov5:v6.2

# Install Norfair
RUN pip install git+https://github.com/tryolabs/norfair.git@master#egg=norfair

WORKDIR /demo/src/
45 changes: 45 additions & 0 deletions demos/camera_motion/README.md
@@ -0,0 +1,45 @@
# Moving Camera Demo

In this example, we show how to estimate camera motion with Norfair.

What's the motivation for estimating camera movement?

- When the camera moves, the apparent movement of the objects can be quite erratic and confuse the tracker; by estimating the camera movement we can stabilize the objects and improve tracking.
- By estimating the positions of objects in a fixed reference frame, we can correctly calculate their trajectories. This helps if you are trying to determine when objects enter a predefined zone in the scene, or if you want to draw their trajectories.

Keep in mind that the camera motion estimation relies on a static background; if the scene is too chaotic, with a lot of movement, the estimation will lose accuracy. Nevertheless, even when the estimation is incorrect, it will not hurt the tracking.
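
To keep the estimation anchored to the background, the demo masks out the regions where objects were detected before sampling points for the optical flow (see `demo.py` below). The following is a minimal sketch of that masking step, using hypothetical `frame` and `boxes` stand-ins for real video data:

```python
import numpy as np

# hypothetical stand-ins for a video frame and detected boxes ([[x1, y1], [x2, y2]] each)
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
boxes = [np.array([[100, 200], [300, 400]])]

# sample optical-flow points only on the background:
# start from a mask of ones and zero out every detected box
mask = np.ones(frame.shape[:2], frame.dtype)
for box in boxes:
    i = box.astype(int)
    mask[i[0, 1] : i[1, 1], i[0, 0] : i[1, 0]] = 0
# the resulting mask is passed to MotionEstimator.update(frame, mask)
```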

## First Example - Translation

This method only works for camera pans and tilts.

![Pan and Tilt](/docs/pan_tilt.png)

In the following video, the tracker on the left loses the person 4 times, while on the right it maintains the tracked object throughout the video:

![camera_stabilization](/docs/camera_stabilization.gif)

> videos generated using command `python demo.py --transformation none --draw-objects --track-boxes --id-size 1.8 --distance-threshold 200 --save video.mp4` and `python demo.py --transformation translation --fixed-camera-scale 2 --draw-objects --track-boxes --id-size 1.8 --distance-threshold 200 --save video.mp4`
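
The snippet below is a minimal sketch of how the translation estimator is wired into the tracking loop, following the API used in `demo.py`; the detector step is elided and `video.mp4` is a hypothetical input file:

```python
from norfair import Tracker, Video
from norfair.camera_motion import MotionEstimator, TranslationTransformationGetter

# estimate camera pan/tilt from the optical flow of background points
motion_estimator = MotionEstimator(
    transformations_getter=TranslationTransformationGetter()
)
tracker = Tracker(distance_function="frobenius", distance_threshold=200)

video = Video(input_path="video.mp4")
for frame in video:
    detections = []  # replace with your detector's norfair.Detection objects
    mask = None  # optionally mask out moving objects, as demo.py does
    coord_transformations = motion_estimator.update(frame, mask)
    tracked_objects = tracker.update(
        detections=detections, coord_transformations=coord_transformations
    )
```
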
## Second Example - Homographies

This method can handle any camera movement, including pan, tilt, rotation, traveling in any direction, and zoom.

In the following video, the correct trajectory of the players is drawn even as the camera moves:

![soccer](/docs/soccer.gif)

> video generated using command `python demo.py --transformation homography --draw-paths --path-history 150 --distance-threshold 200 --track-boxes --max-points=900 --min-distance=14 --save --model yolov5x --hit-counter-max 3 video.mp4` on a snippet of this [video](https://www.youtube.com/watch?v=CGFgHjeEkbY&t=1200s)
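
A sketch of the corresponding homography setup, combining the estimator with `AbsolutePaths` to draw trajectories in a fixed reference frame (parameters taken from the command above; again, the detector step is elided):

```python
from norfair import AbsolutePaths, Tracker, Video
from norfair.camera_motion import HomographyTransformationGetter, MotionEstimator

motion_estimator = MotionEstimator(
    max_points=900,
    min_distance=14,
    transformations_getter=HomographyTransformationGetter(),
)
tracker = Tracker(distance_function="frobenius", distance_threshold=200)
path_drawer = AbsolutePaths(max_history=150, thickness=2)

video = Video(input_path="video.mp4")  # hypothetical input file
for frame in video:
    detections = []  # replace with your detector's norfair.Detection objects
    coord_transformations = motion_estimator.update(frame, None)
    tracked_objects = tracker.update(
        detections=detections, coord_transformations=coord_transformations
    )
    # draw paths in the fixed (absolute) reference frame
    frame = path_drawer.draw(
        frame, tracked_objects, coord_transform=coord_transformations
    )
```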

## Setup

Build and run the Docker container with `./run_gpu.sh`.

Copy a video to the `src` folder.

Within the container, run with the default parameters:

`python demo.py <video>.mp4`

For additional settings, you may display the instructions using `python demo.py --help`.
1 change: 1 addition & 0 deletions demos/camera_motion/requirements.txt
@@ -0,0 +1 @@
yolov5==6.1.8
8 changes: 8 additions & 0 deletions demos/camera_motion/run_gpu.sh
@@ -0,0 +1,8 @@
#!/usr/bin/env -S bash -e
docker build . -t norfair-camera-motion
docker run -it --rm \
    --gpus all \
    --shm-size=1gb \
    -v `realpath .`:/demo \
    norfair-camera-motion \
    bash
284 changes: 284 additions & 0 deletions demos/camera_motion/src/demo.py
@@ -0,0 +1,284 @@
import argparse
from functools import partial

import numpy as np
import torch

from norfair import (
    AbsolutePaths,
    Detection,
    FixedCamera,
    Tracker,
    Video,
    draw_absolute_grid,
    draw_tracked_boxes,
)
from norfair.camera_motion import (
    HomographyTransformationGetter,
    MotionEstimator,
    TranslationTransformationGetter,
)
from norfair.drawing import draw_tracked_objects


def yolo_detections_to_norfair_detections(yolo_detections, track_boxes):
    """Convert YOLOv5 detections to Norfair detections, also returning the raw boxes."""
    norfair_detections = []
    boxes = []
    detections_as_xyxy = yolo_detections.xyxy[0]
    for detection_as_xyxy in detections_as_xyxy:
        detection_as_xyxy = detection_as_xyxy.cpu().numpy()
        bbox = np.array(
            [
                [detection_as_xyxy[0].item(), detection_as_xyxy[1].item()],
                [detection_as_xyxy[2].item(), detection_as_xyxy[3].item()],
            ]
        )
        boxes.append(bbox)
        if track_boxes:
            # track the two corners of the bounding box
            points = bbox
            scores = np.array([detection_as_xyxy[4], detection_as_xyxy[4]])
        else:
            # track only the centroid of the bounding box
            points = bbox.mean(axis=0, keepdims=True)
            scores = detection_as_xyxy[[4]]

        norfair_detections.append(
            Detection(points=points, scores=scores, label=detection_as_xyxy[-1].item())
        )

    return norfair_detections, boxes


def run():
    parser = argparse.ArgumentParser(description="Track objects in a video.")
    parser.add_argument("files", type=str, nargs="+", help="Video files to process")
    parser.add_argument(
        "--model",
        type=str,
        default="yolov5n",
        help="YOLO model to use, possible values are yolov5n, yolov5s, yolov5m, yolov5l, yolov5x",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        help="Confidence threshold of detections",
        default=0.15,
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.8,
        help="Max distance to consider when matching detections and tracked objects",
    )
    parser.add_argument(
        "--initialization-delay",
        type=float,
        default=3,
        help="Min detections needed to initialize a tracked object",
    )
    parser.add_argument(
        "--track-boxes",
        dest="track_boxes",
        action="store_true",
        help="Pass it to track bounding boxes instead of just the centroids",
    )
    parser.add_argument(
        "--hit-counter-max",
        type=int,
        default=30,
        help="Max iterations a tracked object is kept when there are no detections",
    )
    parser.add_argument(
        "--iou-threshold", type=float, help="IoU threshold for detector", default=0.15
    )
    parser.add_argument(
        "--image-size", type=int, help="Size of the images for detector", default=480
    )
    parser.add_argument(
        "--classes", type=int, nargs="+", default=[0], help="Classes to track"
    )
    parser.add_argument(
        "--transformation",
        default="homography",
        help="Type of transformation, possible values are homography, translation, none",
    )
    parser.add_argument(
        "--max-points",
        type=int,
        default=500,
        help="Max points sampled to calculate camera motion",
    )
    parser.add_argument(
        "--min-distance",
        type=float,
        default=7,
        help="Min distance between points sampled to calculate camera motion",
    )
    parser.add_argument(
        "--no-mask-detections",
        dest="mask_detections",
        action="store_false",
        default=True,
        help="By default, regions where objects were detected are not sampled when estimating camera motion. Pass this flag to disable that behavior",
    )
    parser.add_argument(
        "--save",
        dest="save",
        action="store_true",
        help="Pass this flag to save the video instead of showing the frames",
    )
    parser.add_argument(
        "--output-name",
        default=None,
        help="Name of the output file",
    )
    parser.add_argument(
        "--downsample-ratio",
        type=int,
        default=1,
        help="Downsample ratio when showing frames",
    )
    parser.add_argument(
        "--fixed-camera-scale",
        type=float,
        default=0,
        help="Scale of the fixed camera, set to 0 to disable. Note that this only works for translation",
    )
    parser.add_argument(
        "--draw-absolute-grid",
        dest="absolute_grid",
        action="store_true",
        help="Pass this flag to draw an absolute grid for reference",
    )
    parser.add_argument(
        "--draw-objects",
        dest="draw_objects",
        action="store_true",
        help="Pass this flag to draw tracked objects as points, or as boxes if --track-boxes is used",
    )
    parser.add_argument(
        "--draw-paths",
        dest="draw_paths",
        action="store_true",
        help="Pass this flag to draw the paths of the objects (SLOW)",
    )
    parser.add_argument(
        "--path-history",
        type=int,
        default=20,
        help="Length of the paths",
    )
    parser.add_argument(
        "--id-size",
        type=float,
        default=None,
        help="Size multiplier of the ids when drawing. Thickness will adapt to size",
    )
    parser.add_argument(
        "--draw-flow",
        dest="draw_flow",
        action="store_true",
        help="Pass this flag to draw the optical flow of the selected points",
    )

    args = parser.parse_args()

    model = torch.hub.load("ultralytics/yolov5", args.model)
    # let Norfair filter detections by confidence instead of the detector
    model.conf_threshold = 0
    model.iou_threshold = args.iou_threshold
    model.image_size = args.image_size
    model.classes = args.classes

    use_fixed_camera = args.fixed_camera_scale > 0
    tracked_objects = []
    # Process Videos
    for input_path in args.files:
        if args.transformation == "homography":
            transformations_getter = HomographyTransformationGetter()
        elif args.transformation == "translation":
            transformations_getter = TranslationTransformationGetter()
        elif args.transformation == "none":
            transformations_getter = None
        else:
            raise ValueError(f"invalid transformation {args.transformation}")
        if transformations_getter is not None:
            motion_estimator = MotionEstimator(
                max_points=args.max_points,
                min_distance=args.min_distance,
                transformations_getter=transformations_getter,
                draw_flow=args.draw_flow,
            )
        else:
            motion_estimator = None

        if use_fixed_camera:
            fixed_camera = FixedCamera(scale=args.fixed_camera_scale)

        if args.draw_paths:
            path_drawer = AbsolutePaths(max_history=args.path_history, thickness=2)

        video = Video(input_path=input_path)
        show_or_write = (
            video.write
            if args.save
            else partial(video.show, downsample_ratio=args.downsample_ratio)
        )

        tracker = Tracker(
            distance_function="frobenius",
            detection_threshold=args.confidence_threshold,
            distance_threshold=args.distance_threshold,
            initialization_delay=args.initialization_delay,
            hit_counter_max=args.hit_counter_max,
        )
        for frame in video:
            detections = model(frame)
            detections, boxes = yolo_detections_to_norfair_detections(
                detections, args.track_boxes
            )

            mask = None
            if args.mask_detections:
                # create a mask of ones and zero out the regions covered by detections
                mask = np.ones(frame.shape[:2], frame.dtype)
                for b in boxes:
                    i = b.astype(int)
                    mask[i[0, 1] : i[1, 1], i[0, 0] : i[1, 0]] = 0
                if args.track_boxes:
                    # also mask the boxes of currently tracked objects
                    for obj in tracked_objects:
                        i = obj.estimate.astype(int)
                        mask[i[0, 1] : i[1, 1], i[0, 0] : i[1, 0]] = 0

            if motion_estimator is None:
                coord_transformations = None
            else:
                coord_transformations = motion_estimator.update(frame, mask)

            tracked_objects = tracker.update(
                detections=detections, coord_transformations=coord_transformations
            )

            if args.draw_objects:
                draw_tracked_objects(
                    frame,
                    tracked_objects,
                    id_size=args.id_size,
                    id_thickness=None
                    if args.id_size is None
                    else int(args.id_size * 2),
                )

            if args.absolute_grid:
                draw_absolute_grid(frame, coord_transformations)

            if args.draw_paths:
                frame = path_drawer.draw(
                    frame, tracked_objects, coord_transform=coord_transformations
                )

            if use_fixed_camera:
                frame = fixed_camera.adjust_frame(frame, coord_transformations)

            show_or_write(frame)


if __name__ == "__main__":
    run()
Binary file added docs/camera_stabilization.gif
Binary file added docs/pan_tilt.png
Binary file added docs/soccer.gif