-
-
Notifications
You must be signed in to change notification settings - Fork 855
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cmd/yolo-detection: add object detection example using YOLOv8
Signed-off-by: deadprogram <[email protected]>
- Loading branch information
1 parent
e099e20
commit e7a6ac7
Showing
3 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
// What it does: | ||
// | ||
// This example uses the YOLOv8 deep neural network to perform object detection. | ||
// | ||
// Download the ONNX model file from the following notebook: | ||
// | ||
// https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/tutorial.ipynb | ||
// | ||
// How to run: | ||
// | ||
// go run ./cmd/yolo-detection/ [videosource] [modelfile] ([backend] [device]) | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"image" | ||
"image/color" | ||
"os" | ||
|
||
"gocv.io/x/gocv" | ||
) | ||
|
||
func main() { | ||
if len(os.Args) < 3 { | ||
fmt.Println("How to run:\nyolo-detection [videosource] [modelfile] ([backend] [device])") | ||
return | ||
} | ||
|
||
// parse args | ||
deviceID := os.Args[1] | ||
model := os.Args[2] | ||
backend := gocv.NetBackendDefault | ||
if len(os.Args) > 3 { | ||
backend = gocv.ParseNetBackend(os.Args[3]) | ||
} | ||
|
||
target := gocv.NetTargetCPU | ||
if len(os.Args) > 4 { | ||
target = gocv.ParseNetTarget(os.Args[4]) | ||
} | ||
|
||
// open capture device | ||
webcam, err := gocv.OpenVideoCapture(deviceID) | ||
if err != nil { | ||
fmt.Printf("Error opening video capture device: %v\n", deviceID) | ||
return | ||
} | ||
defer webcam.Close() | ||
|
||
window := gocv.NewWindow("YOLO Detection") | ||
defer window.Close() | ||
|
||
img := gocv.NewMat() | ||
defer img.Close() | ||
|
||
// open DNN object tracking model | ||
net := gocv.ReadNetFromONNX(model) | ||
if net.Empty() { | ||
fmt.Printf("Error reading network model from : %v\n", model) | ||
return | ||
} | ||
defer net.Close() | ||
net.SetPreferableBackend(gocv.NetBackendType(backend)) | ||
net.SetPreferableTarget(gocv.NetTargetType(target)) | ||
|
||
outputNames := getOutputNames(&net) | ||
if len(outputNames) == 0 { | ||
fmt.Println("Error reading output layer names") | ||
return | ||
} | ||
|
||
fmt.Printf("Start reading device: %v\n", deviceID) | ||
|
||
for { | ||
if ok := webcam.Read(&img); !ok { | ||
fmt.Printf("Device closed: %v\n", deviceID) | ||
return | ||
} | ||
if img.Empty() { | ||
continue | ||
} | ||
|
||
detect(&net, &img, outputNames) | ||
|
||
window.IMShow(img) | ||
if window.WaitKey(1) >= 0 { | ||
break | ||
} | ||
} | ||
} | ||
|
||
var ( | ||
ratio = 0.003921568627 | ||
mean = gocv.NewScalar(0, 0, 0, 0) | ||
swapRGB = false | ||
padValue = gocv.NewScalar(144.0, 0, 0, 0) | ||
|
||
scoreThreshold float32 = 0.5 | ||
nmsThreshold float32 = 0.4 | ||
) | ||
|
||
func detect(net *gocv.Net, src *gocv.Mat, outputNames []string) { | ||
params := gocv.NewImageToBlobParams(ratio, image.Pt(640, 640), mean, swapRGB, gocv.MatTypeCV32F, gocv.DataLayoutNCHW, gocv.PaddingModeLetterbox, padValue) | ||
blob := gocv.BlobFromImageWithParams(*src, params) | ||
defer blob.Close() | ||
|
||
// feed the blob into the detector | ||
net.SetInput(blob, "") | ||
|
||
// run a forward pass thru the network | ||
probs := net.ForwardLayers(outputNames) | ||
defer func() { | ||
for _, prob := range probs { | ||
prob.Close() | ||
} | ||
}() | ||
|
||
boxes, confidences, classIds := performDetection(probs) | ||
if len(boxes) == 0 { | ||
fmt.Println("No classes detected") | ||
return | ||
} | ||
|
||
iboxes := params.BlobRectsToImageRects(boxes, image.Pt(src.Cols(), src.Rows())) | ||
indices := gocv.NMSBoxes(iboxes, confidences, scoreThreshold, nmsThreshold) | ||
drawRects(src, iboxes, classes, classIds, indices) | ||
} | ||
|
||
func getOutputNames(net *gocv.Net) []string { | ||
var outputLayers []string | ||
for _, i := range net.GetUnconnectedOutLayers() { | ||
layer := net.GetLayer(i) | ||
layerName := layer.GetName() | ||
if layerName != "_input" { | ||
outputLayers = append(outputLayers, layerName) | ||
} | ||
} | ||
|
||
return outputLayers | ||
} | ||
|
||
func performDetection(outs []gocv.Mat) ([]image.Rectangle, []float32, []int) { | ||
var classIds []int | ||
var confidences []float32 | ||
var boxes []image.Rectangle | ||
|
||
// needed for yolov8 | ||
gocv.TransposeND(outs[0], []int{0, 2, 1}, &outs[0]) | ||
|
||
for _, out := range outs { | ||
out = out.Reshape(1, out.Size()[1]) | ||
|
||
data, _ := out.DataPtrFloat32() | ||
for i := 0; i < out.Rows(); i, data = i+1, data[out.Cols():] { | ||
|
||
scoresCol := out.RowRange(i, i+1) | ||
|
||
scores := scoresCol.ColRange(4, out.Cols()) | ||
_, confidence, _, classIDPoint := gocv.MinMaxLoc(scores) | ||
|
||
if confidence > 0.5 { | ||
centerX := data[0] | ||
centerY := data[1] | ||
width := data[2] | ||
height := data[3] | ||
|
||
left := centerX - width/2 | ||
top := centerY - height/2 | ||
right := centerX + width/2 | ||
bottom := centerY + height/2 | ||
classIds = append(classIds, classIDPoint.X) | ||
confidences = append(confidences, float32(confidence)) | ||
|
||
boxes = append(boxes, image.Rect(int(left), int(top), int(right), int(bottom))) | ||
} | ||
} | ||
} | ||
|
||
return boxes, confidences, classIds | ||
} | ||
|
||
func drawRects(img *gocv.Mat, boxes []image.Rectangle, classes []string, classIds []int, indices []int) []string { | ||
var detectClass []string | ||
for _, idx := range indices { | ||
if idx == 0 { | ||
continue | ||
} | ||
gocv.Rectangle(img, image.Rect(boxes[idx].Min.X, boxes[idx].Min.Y, boxes[idx].Max.X, boxes[idx].Max.Y), color.RGBA{0, 255, 0, 0}, 2) | ||
gocv.PutText(img, classes[classIds[idx]], image.Point{boxes[idx].Min.X, boxes[idx].Min.Y - 10}, gocv.FontHersheyPlain, 0.6, color.RGBA{0, 255, 0, 0}, 1) | ||
detectClass = append(detectClass, classes[classIds[idx]]) | ||
} | ||
|
||
return detectClass | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package main | ||
|
||
// classes holds the YOLOv8 class labels. The position of a label in the
// slice is the class id used to look it up.
var classes = []string{
	"person", "bicycle", "car", "motorcycle", "airplane",
	"bus", "train", "truck", "boat", "traffic light",
	"fire hydrant", "stop sign", "parking meter", "bench", "bird",
	"cat", "dog", "horse", "sheep", "cow",
	"elephant", "bear", "zebra", "giraffe", "backpack",
	"umbrella", "handbag", "tie", "suitcase", "frisbee",
	"skis", "snowboard", "sports ball", "kite", "baseball bat",
	"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
	"wine glass", "cup", "fork", "knife", "spoon",
	"bowl", "banana", "apple", "sandwich", "orange",
	"broccoli", "carrot", "hot dog", "pizza", "donut",
	"cake", "chair", "couch", "potted plant", "bed",
	"dining table", "toilet", "tv", "laptop", "mouse",
	"remote", "keyboard", "cell phone", "microwave", "oven",
	"toaster", "sink", "refrigerator", "book", "clock",
	"vase", "scissors", "teddy bear", "hair drier", "toothbrush",
}