Skip to content

Commit

Permalink
cmd/yolo-detection: add object detection example using YOLOv8
Browse files Browse the repository at this point in the history
Signed-off-by: deadprogram <[email protected]>
  • Loading branch information
deadprogram committed Jun 26, 2024
1 parent e099e20 commit e7a6ac7
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
6 changes: 6 additions & 0 deletions cmd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,9 @@ https://github.com/hybridgroup/gocv/blob/release/cmd/version/main.go
This example demonstrates a couple different uses of the XPhoto module. It can use the GrayworldWB class with BalanceWhite image to save an image file on disk. It can also use the Inpaint functions with inpaint algorithms type to save an image file on disk.

https://github.com/hybridgroup/gocv/blob/release/cmd/xphoto/main.go

## YOLO DNN Detection

Use the YOLOv8 deep neural network to detect objects in a video stream and draw labeled bounding boxes around them.

https://github.com/hybridgroup/gocv/blob/release/cmd/yolo-detection/main.go
194 changes: 194 additions & 0 deletions cmd/yolo-detection/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// What it does:
//
// This example uses the YOLOv8 deep neural network to perform object detection.
//
// Download the ONNX model file from the following notebook:
//
// https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/tutorial.ipynb
//
// How to run:
//
// go run ./cmd/yolo-detection/ [videosource] [modelfile] ([backend] [device])
package main

import (
"fmt"
"image"
"image/color"
"os"

"gocv.io/x/gocv"
)

// main parses the command line, opens the video source and the ONNX model,
// then runs detection on every frame and shows the annotated result in a
// window until a key is pressed or the video source stops delivering frames.
//
// Arguments: [videosource] [modelfile] ([backend] [device]). The backend and
// target device are optional and default to gocv.NetBackendDefault and
// gocv.NetTargetCPU.
func main() {
	if len(os.Args) < 3 {
		fmt.Println("How to run:\nyolo-detection [videosource] [modelfile] ([backend] [device])")
		return
	}

	// parse args
	deviceID := os.Args[1]
	model := os.Args[2]
	backend := gocv.NetBackendDefault
	if len(os.Args) > 3 {
		backend = gocv.ParseNetBackend(os.Args[3])
	}

	target := gocv.NetTargetCPU
	if len(os.Args) > 4 {
		target = gocv.ParseNetTarget(os.Args[4])
	}

	// open capture device
	webcam, err := gocv.OpenVideoCapture(deviceID)
	if err != nil {
		// Report the underlying cause as well as the device that failed.
		fmt.Printf("Error opening video capture device %v: %v\n", deviceID, err)
		return
	}
	defer webcam.Close()

	window := gocv.NewWindow("YOLO Detection")
	defer window.Close()

	img := gocv.NewMat()
	defer img.Close()

	// open DNN object detection model
	net := gocv.ReadNetFromONNX(model)
	if net.Empty() {
		fmt.Printf("Error reading network model from : %v\n", model)
		return
	}
	defer net.Close()
	net.SetPreferableBackend(gocv.NetBackendType(backend))
	net.SetPreferableTarget(gocv.NetTargetType(target))

	outputNames := getOutputNames(&net)
	if len(outputNames) == 0 {
		fmt.Println("Error reading output layer names")
		return
	}

	fmt.Printf("Start reading device: %v\n", deviceID)

	for {
		if ok := webcam.Read(&img); !ok {
			fmt.Printf("Device closed: %v\n", deviceID)
			return
		}
		if img.Empty() {
			continue
		}

		// detect draws its results directly onto img.
		detect(&net, &img, outputNames)

		window.IMShow(img)
		// Any key press ends the loop.
		if window.WaitKey(1) >= 0 {
			break
		}
	}
}

// Preprocessing settings that detect hands to gocv.NewImageToBlobParams.
var (
	// ratio is passed as the first argument (presumably the pixel scale
	// factor); the value is approximately 1/255.
	ratio = 0.003921568627
	// mean is an all-zero scalar.
	mean = gocv.NewScalar(0, 0, 0, 0)
	// swapRGB is passed through unchanged; false here.
	swapRGB = false
	// padValue is the scalar supplied alongside the letterbox padding mode.
	padValue = gocv.NewScalar(144.0, 0, 0, 0)
)

// Thresholds that detect hands to gocv.NMSBoxes.
var (
	scoreThreshold float32 = 0.5
	nmsThreshold   float32 = 0.4
)

// detect runs one forward pass of net over the frame in src and draws the
// detections that survive non-maximum suppression directly onto src.
//
// outputNames lists the output layers to evaluate. When performDetection
// yields no candidate boxes, a notice is printed and src is left untouched.
func detect(net *gocv.Net, src *gocv.Mat, outputNames []string) {
	// Convert the frame into a 640x640 letterboxed NCHW float32 blob.
	blobParams := gocv.NewImageToBlobParams(
		ratio,
		image.Pt(640, 640),
		mean,
		swapRGB,
		gocv.MatTypeCV32F,
		gocv.DataLayoutNCHW,
		gocv.PaddingModeLetterbox,
		padValue,
	)
	inputBlob := gocv.BlobFromImageWithParams(*src, blobParams)
	defer inputBlob.Close()

	// Hand the blob to the network and evaluate the requested output layers.
	net.SetInput(inputBlob, "")
	outputs := net.ForwardLayers(outputNames)
	defer func() {
		for _, output := range outputs {
			output.Close()
		}
	}()

	rects, scores, ids := performDetection(outputs)
	if len(rects) == 0 {
		fmt.Println("No classes detected")
		return
	}

	// Map boxes from blob coordinates back to the original frame, then keep
	// only the boxes selected by non-maximum suppression.
	frameSize := image.Pt(src.Cols(), src.Rows())
	frameRects := blobParams.BlobRectsToImageRects(rects, frameSize)
	kept := gocv.NMSBoxes(frameRects, scores, scoreThreshold, nmsThreshold)
	drawRects(src, frameRects, classes, ids, kept)
}

// getOutputNames returns the names of the network's unconnected output layers,
// omitting any layer named "_input".
//
// Each Layer handle obtained from the network is closed as soon as its name
// has been read, so the native object behind it is not leaked.
func getOutputNames(net *gocv.Net) []string {
	var outputLayers []string
	for _, i := range net.GetUnconnectedOutLayers() {
		layer := net.GetLayer(i)
		layerName := layer.GetName()
		// Close explicitly rather than defer: a defer inside the loop would
		// keep every handle alive until the function returns.
		layer.Close()

		if layerName != "_input" {
			outputLayers = append(outputLayers, layerName)
		}
	}

	return outputLayers
}

// performDetection converts the raw network outputs into candidate detections.
//
// It returns three parallel slices: the bounding boxes (in blob coordinates),
// the confidence of the best-scoring class for each box, and that class's id.
// Only candidates whose best confidence exceeds 0.5 are returned. An empty
// outs yields three nil slices.
//
// Note that outs[0] is transposed in place before it is read.
func performDetection(outs []gocv.Mat) ([]image.Rectangle, []float32, []int) {
	var classIds []int
	var confidences []float32
	var boxes []image.Rectangle

	if len(outs) == 0 {
		return boxes, confidences, classIds
	}

	// needed for yolov8: swap the last two axes so that, after the reshape
	// below, every row holds one candidate.
	gocv.TransposeND(outs[0], []int{0, 2, 1}, &outs[0])

	for i := range outs {
		boxes, confidences, classIds = appendDetections(&outs[i], boxes, confidences, classIds)
	}

	return boxes, confidences, classIds
}

// appendDetections scans one network output and appends every candidate whose
// best class confidence exceeds 0.5 to the given slices, returning them.
func appendDetections(out *gocv.Mat, boxes []image.Rectangle, confidences []float32, classIds []int) ([]image.Rectangle, []float32, []int) {
	const minConfidence = 0.5

	size := out.Size()
	if len(size) < 2 {
		return boxes, confidences, classIds
	}

	// Reshape returns a new Mat header that must be closed separately from
	// the Mat it was derived from.
	rows := out.Reshape(1, size[1])
	defer rows.Close()

	// Each row is read as centerX, centerY, width, height followed by the
	// per-class scores, so at least five columns are required.
	cols := rows.Cols()
	if cols <= 4 {
		return boxes, confidences, classIds
	}

	data, err := rows.DataPtrFloat32()
	if err != nil {
		fmt.Printf("Error reading network output: %v\n", err)
		return boxes, confidences, classIds
	}

	// data is advanced by one row per iteration; the length check stops the
	// loop instead of panicking if the buffer is shorter than expected.
	for i := 0; i < rows.Rows() && len(data) >= cols; i, data = i+1, data[cols:] {
		confidence, classID := bestClass(&rows, i)
		if !(confidence > minConfidence) {
			continue
		}

		centerX := data[0]
		centerY := data[1]
		width := data[2]
		height := data[3]

		left := centerX - width/2
		top := centerY - height/2
		right := centerX + width/2
		bottom := centerY + height/2

		classIds = append(classIds, classID)
		confidences = append(confidences, confidence)
		boxes = append(boxes, image.Rect(int(left), int(top), int(right), int(bottom)))
	}

	return boxes, confidences, classIds
}

// bestClass returns the highest class score in the given row of m (columns 4
// onward) together with the index of that class. The temporary row and column
// views are closed before returning.
func bestClass(m *gocv.Mat, row int) (float32, int) {
	rowView := m.RowRange(row, row+1)
	defer rowView.Close()

	scores := rowView.ColRange(4, m.Cols())
	defer scores.Close()

	_, confidence, _, classIDPoint := gocv.MinMaxLoc(scores)
	return float32(confidence), classIDPoint.X
}

// drawRects draws a green rectangle and class label on img for every box
// selected by indices, and returns the labels that were drawn, in order.
//
// boxes and classIds are parallel slices; indices holds positions into them
// and classes maps a class id to its label. Index 0 is a valid selection and
// is drawn like any other. Entries that point outside boxes or classIds, or
// whose class id has no label in classes, are skipped rather than causing an
// out-of-range panic.
func drawRects(img *gocv.Mat, boxes []image.Rectangle, classes []string, classIds []int, indices []int) []string {
	var detectClass []string
	green := color.RGBA{0, 255, 0, 0}

	for _, idx := range indices {
		if idx < 0 || idx >= len(boxes) || idx >= len(classIds) {
			continue
		}
		classID := classIds[idx]
		if classID < 0 || classID >= len(classes) {
			continue
		}

		box := boxes[idx]
		label := classes[classID]

		gocv.Rectangle(img, image.Rect(box.Min.X, box.Min.Y, box.Max.X, box.Max.Y), green, 2)
		// Place the label slightly above the top-left corner of the box.
		gocv.PutText(img, label, image.Point{box.Min.X, box.Min.Y - 10}, gocv.FontHersheyPlain, 0.6, green, 1)
		detectClass = append(detectClass, label)
	}

	return detectClass
}
14 changes: 14 additions & 0 deletions cmd/yolo-detection/yolo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package main

// classes holds the 80 class labels for the YOLOv8 model, ordered so that a
// label's position in the slice is its class id.
var classes = []string{
	"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
	"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
	"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
	"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
	"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
	"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
	"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
	"chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
	"mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
	"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
}

0 comments on commit e7a6ac7

Please sign in to comment.