Skip to content

Commit

Permalink
[pytorch] Adds Yolov11 model to model zoo (#3516)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Nov 12, 2024
1 parent 911c35a commit 890d980
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ public class OrtModelZoo extends ModelZoo {

/**
 * Populates this zoo by calling addModel once per model entry.
 *
 * <p>Every entry below is created through REPOSITORY.model(...) with an application constant,
 * GROUP_ID, an artifact id and the version string "0.0.1".
 */
OrtModelZoo() {
// Entries are grouped by application constant; within each group the artifact ids are listed
// in ascending string order (e.g. "yolo11n" before "yolo5s" before "yolov8n").
addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"));
addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolo11n-seg", "0.0.1"));
addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-base-plus", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-small", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo11n", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo5s", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1"));
addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolo11n-pose", "0.0.1"));
addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolov8n-pose", "0.0.1"));
addModel(REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1"));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/instance_segmentation",
"groupId": "ai.djl.onnxruntime",
"artifactId": "yolo11n-seg",
"name": "yolo11n seg",
"description": "yolo11n Instance Segmentation",
"website": "http://www.djl.ai/engines/onnxruntime/onnxruntime-engine",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolo11n-seg",
"properties": {
},
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"threshold": 0.25,
"translatorFactory": "ai.djl.modality.cv.translator.YoloSegmentationTranslatorFactory"
},
"files": {
"model": {
"uri": "0.0.1/yolo11n-seg.zip",
"sha1Hash": "981707febd20985b4131410498c36696bb5d442e",
"name": "",
"size": 10181964
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/object_detection",
"groupId": "ai.djl.onnxruntime",
"artifactId": "yolo11n",
"name": "yolo11n",
"description": "Yolo11n Object Detection",
"website": "http://www.djl.ai/engines/onnxruntime/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolo11n",
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"rescale": true,
"optApplyRatio": true,
"threshold": 0.6,
"translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory"
},
"files": {
"model": {
"uri": "0.0.1/yolo11n.zip",
"name": "",
"sha1Hash": "e2b16662393be2ae7ae41c641ebdb1c919675878",
"size": 9318872
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/pose_estimation",
"groupId": "ai.djl.onnxruntime",
"artifactId": "yolo11n-pose",
"name": "yolo11n pose",
"description": "yolo11n Pose Estimation",
"website": "http://www.djl.ai/engines/onnxruntime/onnxruntime-engine",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolo11n-pose",
"properties": {
},
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"threshold": 0.25,
"translatorFactory": "ai.djl.modality.cv.translator.YoloPoseTranslatorFactory"
},
"files": {
"model": {
"uri": "0.0.1/yolo11n-pose.zip",
"sha1Hash": "3206acb87f1a8a69eeca017639376b7549ae1e92",
"name": "",
"size": 10176054
}
}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@ public class PtModelZoo extends ModelZoo {
addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"));
addModel(
REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolo11n-seg", "0.0.1"));
addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.1"));
addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo11n", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1"));
addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1"));
addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolo11n-pose", "0.0.1"));
addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolov8n-pose", "0.0.1"));
addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1"));
addModel(REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/instance_segmentation",
"groupId": "ai.djl.pytorch",
"artifactId": "yolo11n-seg",
"name": "yolo11n seg",
"description": "yolo11n Instance Segmentation",
"website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolo11n-seg",
"properties": {
},
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"threshold": 0.25,
"translatorFactory": "ai.djl.modality.cv.translator.YoloSegmentationTranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/yolo11n-seg.zip",
"sha1Hash": "22a1d6a33fb9a5da738d04fc82c2faab8104ecc7",
"name": "",
"size": 10380298
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/object_detection",
"groupId": "ai.djl.pytorch",
"artifactId": "yolo11n",
"name": "yolo11n",
"description": "Yolo11n Object Detection",
"website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolo11n",
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"rescale": true,
"optApplyRatio": true,
"threshold": 0.6,
"translatorFactory": "ai.djl.modality.cv.translator.YoloV8TranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/yolo11n.zip",
"name": "",
"sha1Hash": "5af1900a48422b91d585bf096b1a03da310e6ed3",
"size": 9499137
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/pose_estimation",
"groupId": "ai.djl.pytorch",
"artifactId": "yolo11n-pose",
"name": "yolo11n pose",
"description": "yolo11n Pose Estimation",
"website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolo11n-pose",
"properties": {
},
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"threshold": 0.25,
"translatorFactory": "ai.djl.modality.cv.translator.YoloPoseTranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/yolo11n-pose.zip",
"sha1Hash": "73fda7970f476d19a8daf4d0827caeaa2f203487",
"name": "",
"size": 10371051
}
}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.setTypes(Image.class, DetectedObjects.class)
.optModelUrls("djl://ai.djl.pytorch/yolov8n-seg")
.optModelUrls("djl://ai.djl.pytorch/yolo11n-seg")
.optTranslatorFactory(new YoloSegmentationTranslatorFactory())
.optProgress(new ProgressBar())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static Joints[] predict() throws IOException, ModelException, TranslateEx
Criteria<Image, Joints[]> criteria =
Criteria.builder()
.setTypes(Image.class, Joints[].class)
.optModelUrls("djl://ai.djl.pytorch/yolov8n-pose")
.optModelUrls("djl://ai.djl.pytorch/yolo11n-pose")
.optTranslatorFactory(new YoloPoseTranslatorFactory())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
import java.nio.file.Paths;

/** An example of inference using a YOLO object detection model. */
public final class Yolov8Detection {
public final class YoloDetection {

private static final Logger logger = LoggerFactory.getLogger(Yolov8Detection.class);
private static final Logger logger = LoggerFactory.getLogger(YoloDetection.class);

private Yolov8Detection() {}
private YoloDetection() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
DetectedObjects detection = predict();
Expand All @@ -49,11 +49,11 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
Image img = ImageFactory.getInstance().fromFile(imgPath);

// Use a model from the DJL OnnxRuntime model zoo; the model archive can be found at:
// https://mlrepo.djl.ai/model/cv/object_detection/ai/djl/onnxruntime/yolov8n/0.0.1/yolov8n.zip
// https://mlrepo.djl.ai/model/cv/object_detection/ai/djl/onnxruntime/yolo11n/0.0.1/yolo11n.zip
Criteria<Path, DetectedObjects> criteria =
Criteria.builder()
.setTypes(Path.class, DetectedObjects.class)
.optModelUrls("djl://ai.djl.onnxruntime/yolov8n")
.optModelUrls("djl://ai.djl.onnxruntime/yolo11n")
.optEngine("OnnxRuntime")
.optArgument("width", 640)
.optArgument("height", 640)
Expand All @@ -76,7 +76,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
DetectedObjects detection = predictor.predict(imgPath);
if (detection.getNumberOfObjects() > 0) {
img.drawBoundingBoxes(detection);
Path output = outputPath.resolve("yolov8_detected.png");
Path output = outputPath.resolve("yolo_detected.png");
try (OutputStream os = Files.newOutputStream(output)) {
img.save(os, "png");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@

import java.io.IOException;

public class Yolov8DetectionTest {
public class YoloDetectionTest {

@Test
public void testYolov8Detection() throws ModelException, TranslateException, IOException {
TestRequirements.notGpu();

DetectedObjects result = Yolov8Detection.predict();
DetectedObjects result = YoloDetection.predict();

Assert.assertTrue(result.getNumberOfObjects() >= 1);
Assert.assertTrue(result.getClassNames().contains("dog"));
Classifications.Classification obj = result.best();
String className = obj.getClassName();
Assert.assertEquals(className, "dog");
Assert.assertTrue(obj.getProbability() > 0.6);
}
}

0 comments on commit 890d980

Please sign in to comment.