diff --git a/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java b/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java
index 6f03cd24f..683ebe3ba 100644
--- a/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java
+++ b/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java
@@ -12,25 +12,33 @@
  */
 package ai.djl.examples.inference;
 
-import ai.djl.Application;
+import ai.djl.Model;
 import ai.djl.ModelException;
-import ai.djl.engine.Engine;
-import ai.djl.inference.Predictor;
-import ai.djl.modality.cv.Image;
-import ai.djl.modality.cv.ImageFactory;
+import ai.djl.basicmodelzoo.nlp.SimpleTextDecoder;
+import ai.djl.basicmodelzoo.nlp.SimpleTextEncoder;
 import ai.djl.modality.cv.output.DetectedObjects;
-import ai.djl.repository.zoo.Criteria;
-import ai.djl.repository.zoo.ModelZoo;
-import ai.djl.repository.zoo.ZooModel;
-import ai.djl.training.util.ProgressBar;
+import ai.djl.modality.nlp.Decoder;
+import ai.djl.modality.nlp.Encoder;
+import ai.djl.modality.nlp.EncoderDecoder;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.recurrent.LSTM;
+import ai.djl.training.DefaultTrainingConfig;
+import ai.djl.training.EasyTrain;
+import ai.djl.training.Trainer;
+import ai.djl.training.TrainingConfig;
+import ai.djl.training.dataset.ArrayDataset;
+import ai.djl.training.listener.TrainingListener;
+import ai.djl.training.loss.Loss;
+import ai.djl.training.optimizer.Optimizer;
+import ai.djl.training.tracker.Tracker;
 import ai.djl.translate.TranslateException;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
+
 /**
  * An example of inference using an object detection model.
  *
@@ -50,45 +58,87 @@ public final class ObjectDetection {
     }
 
+    /**
+     * Trains a small LSTM encoder-decoder on a hard-coded sliding-window series.
+     *
+     * <p>No object detection is run: the model is built, trained and then closed,
+     * so there is no detection result to hand back.
+     *
+     * @return always {@code null}
+     * @throws IOException if an I/O error occurs during training
+     * @throws ModelException retained from the existing signature for callers
+     * @throws TranslateException if batch translation fails during training
+     */
     public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
-        Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
-        Image img = ImageFactory.getInstance().fromFile(imageFile);
-
-        String backbone;
-        if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
-            backbone = "mobilenet_v2";
-        } else {
-            backbone = "resnet50";
+        // NOTE(review): fixed 20 s pause before any work starts; presumably there to allow
+        // attaching a profiler or debugger -- confirm it is still wanted.
+        try {
+            Thread.sleep(20000);
+        } catch (InterruptedException e) {
+            // Restore the interrupt flag instead of swallowing the interruption.
+            Thread.currentThread().interrupt();
+            logger.warn("Interrupted during the start-up pause", e);
         }
-
-        Criteria<Image, DetectedObjects> criteria =
-                Criteria.builder()
-                        .optApplication(Application.CV.OBJECT_DETECTION)
-                        .setTypes(Image.class, DetectedObjects.class)
-                        .optFilter("backbone", backbone)
-                        .optProgress(new ProgressBar())
-                        .build();
-
-        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
-            try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
-                DetectedObjects detection = predictor.predict(img);
-                saveBoundingBoxImage(img, detection);
-                return detection;
+        try (Model model = Model.newInstance("time-series")) {
+            NDManager nd = model.getNDManager();
+            NDArray inputs =
+                    nd.create(
+                            new float[][] {
+                                {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+                                {2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f},
+                                {3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f},
+                                {4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f},
+                                {5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f},
+                                {6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f},
+                                {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f},
+                                {8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f},
+                                {9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f},
+                                {10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f}
+                            });
+            Shape inputShape = inputs.getShape();
+            long cnt = inputShape.get(0);
+            long dur = inputShape.get(1);
+            long predDur = 2L;
+            long trainDur = 3L;
+            // With the 6-column input above start is 0: columns [0, 3) feed the encoder and
+            // [3, 5) are the labels. NOTE(review): the last column is never used -- confirm.
+            long start = dur - trainDur - predDur - 1;
+            NDArray encoderInputs =
+                    inputs.get(":," + start + ":" + (start + trainDur))
+                            .reshape(new Shape(cnt, trainDur, 1L));
+            NDArray decoderInputs =
+                    inputs.get(":," + (start + trainDur) + ":" + (start + trainDur + predDur))
+                            .reshape(new Shape(cnt, predDur, 1L));
+            int batchSize = 1;
+            ArrayDataset trainingDataset =
+                    new ArrayDataset.Builder()
+                            .setData(encoderInputs)
+                            .optLabels(decoderInputs)
+                            .setSampling(batchSize, false)
+                            .build();
+            Encoder encoder =
+                    new SimpleTextEncoder(
+                            LSTM.builder().setNumStackedLayers(1).setStateSize(2).build());
+            Decoder decoder =
+                    new SimpleTextDecoder(
+                            LSTM.builder().setNumStackedLayers(1).setStateSize(2).build(), 1);
+            EncoderDecoder net = new EncoderDecoder(encoder, decoder);
+            model.setBlock(net);
+            Loss loss = Loss.l1Loss();
+            Tracker tracker = Tracker.fixed(0.001f);
+            Optimizer optimizer = Optimizer.sgd().setLearningRateTracker(tracker).build();
+            TrainingListener[] listeners = TrainingListener.Defaults.logging();
+            TrainingConfig config =
+                    new DefaultTrainingConfig(loss)
+                            .optOptimizer(optimizer)
+                            .addTrainingListeners(listeners);
+            int numEpochs = 10;
+            try (Trainer trainer = model.newTrainer(config)) {
+                trainer.initialize(encoderInputs.getShape(), decoderInputs.getShape());
+                // fit() already loops over numEpochs epochs, so it is called exactly once.
+                EasyTrain.fit(trainer, numEpochs, trainingDataset, null);
             }
         }
-    }
-
-    private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
-            throws IOException {
-        Path outputDir = Paths.get("build/output");
-        Files.createDirectories(outputDir);
-
-        // Make image copy with alpha channel because original image was jpg
-        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
-        newImage.drawBoundingBoxes(detection);
-
-        Path imagePath = outputDir.resolve("detected-dog_bike_car.png");
-        // OpenJDK can't save jpg with alpha channel
-        newImage.save(Files.newOutputStream(imagePath), "png");
-        logger.info("Detected objects image has been saved in: {}", imagePath);
+        // Nothing is detected here; null keeps the existing return type for callers.
+        return null;
     }
 }