Skip to content

Commit

Permalink
[examples] Prepare for MXNet deprecation (#3157)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored May 6, 2024
1 parent 4b4d031 commit 853c252
Show file tree
Hide file tree
Showing 65 changed files with 214 additions and 278 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public Shape[] getOutputShapes(Shape[] inputShapes) {
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes, DataType[] dataTypes) {
try (NDManager manager = NDManager.newBaseManager()) {
try (NDManager manager = NDManager.newBaseManager("PyTorch")) {
NDList list = new NDList();
for (int i = 0; i < inputShapes.length; i++) {
list.add(
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/train_transfer_fresh_fruit.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) {

DefaultTrainingConfig config = new DefaultTrainingConfig(new SoftmaxCrossEntropy("SoftmaxCrossEntropy"))
.addEvaluator(new Accuracy())
.optDevices(Engine.getInstance().getDevices(1))
.optDevices(Engine.getEngine("PyTorch").getDevices(1))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
Expand Down Expand Up @@ -68,14 +68,14 @@ public static String predict() throws IOException, TranslateException, ModelExce
.optApplication(Application.NLP.QUESTION_ANSWER)
.setTypes(QAInput.class, String.class)
.optFilter("backbone", "bert")
.optEngine(Engine.getDefaultEngineName())
.optEngine("PyTorch")
.optDevice(Device.cpu())
.optProgress(new ProgressBar())
.build();

try (ZooModel<QAInput, String> model = criteria.loadModel()) {
try (Predictor<QAInput, String> predictor = model.newPredictor()) {
return predictor.predict(input);
}
try (ZooModel<QAInput, String> model = criteria.loadModel();
Predictor<QAInput, String> predictor = model.newPredictor()) {
return predictor.predict(input);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public static Classifications predict()
Criteria.builder()
.optApplication(Application.NLP.SENTIMENT_ANALYSIS)
.setTypes(String.class, Classifications.class)
.optEngine("PyTorch")
// This model was traced on CPU and can only run on CPU
.optDevice(Device.cpu())
.optProgress(new ProgressBar())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package ai.djl.examples.inference;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.audio.Audio;
Expand Down Expand Up @@ -56,6 +57,7 @@ public static String predict() throws IOException, ModelException, TranslateExce
Criteria.builder()
.setTypes(Audio.class, String.class)
.optModelUrls(url)
.optDevice(Device.cpu()) // torchscript model only support CPU
.optTranslatorFactory(new SpeechRecognitionTranslatorFactory())
.optModelName("wav2vec2.ptl")
.optEngine("PyTorch")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.examples.inference.clip;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
Expand Down Expand Up @@ -45,6 +46,7 @@ public ClipModel() throws ModelException, IOException {
.optModelUrls("https://resources.djl.ai/demo/pytorch/clip.zip")
.optTranslator(new NoopTranslator())
.optEngine("PyTorch")
.optDevice(Device.cpu()) // torchscript model only support CPU
.build();
clip = criteria.loadModel();
imageFeatureExtractor = clip.newPredictor(new ImageTranslator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static Classifications predict() throws IOException, ModelException, Tran
Image img = ImageFactory.getInstance().fromFile(imageFile);

String modelName = "mlp";
try (Model model = Model.newInstance(modelName)) {
try (Model model = Model.newInstance(modelName, "PyTorch")) {
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

// Assume you have run TrainMnist.java example, and saved model in build/model folder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
Expand Down Expand Up @@ -54,19 +53,12 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
Image img = ImageFactory.getInstance().fromFile(imageFile);

String backbone;
if ("TensorFlow".equals(Engine.getDefaultEngineName())) {
backbone = "mobilenet_v2";
} else {
backbone = "resnet50";
}

Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
.setTypes(Image.class, DetectedObjects.class)
.optFilter("backbone", backbone)
.optEngine(Engine.getDefaultEngineName())
.optFilter("backbone", "mobilenet_v2")
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/
package ai.djl.examples.inference.nlp;

import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
Expand All @@ -22,7 +22,6 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.DeferredTranslatorFactory;
import ai.djl.translate.TranslateException;
Expand All @@ -39,20 +38,13 @@ public final class RollingBatch {

private RollingBatch() {}

public static void main(String[] args)
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
public static void main(String[] args) throws ModelException, IOException, TranslateException {
String[] ret = seqBatchSchedulerWithPyTorchContrastive();
logger.info("{}", ret[0]);
}

public static String[] seqBatchSchedulerWithPyTorchContrastive()
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
throws ModelException, IOException, TranslateException {
String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip";

Criteria<NDList, CausalLMOutput> criteria =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.examples.inference.nlp;

import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
Expand Down Expand Up @@ -161,10 +162,7 @@ public static String[] generateTextWithPyTorchBeam()
}

public static String[] generateTextWithOnnxRuntimeBeam()
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
throws ModelException, IOException, TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
long padTokenId = 220;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.utils.DynamicBuffer;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
Expand Down Expand Up @@ -63,7 +62,6 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -82,9 +80,7 @@ public static void main(String[] args) throws IOException, TranslateException, M

public static Map<String, Float> predict()
throws IOException, TranslateException, ModelException {
Engine engine = Engine.getInstance();
NDManager manager = engine.newBaseManager();
String engineName = engine.getEngineName().toLowerCase(Locale.ROOT);
NDManager manager = NDManager.newBaseManager("MXNet");

// To use local dataset, users can load data as follows
// Repository repository = Repository.newInstance("local_dataset",
Expand All @@ -102,12 +98,13 @@ public static Map<String, Float> predict()
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008#file-m5torch-py

// Here you can also use local file: modelUrl = "LOCAL_PATH/deepar.pt";
String modelUrl = "djl://ai.djl." + engineName + "/deepar/0.0.1/m5forecast";
String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls(modelUrl)
.optEngine("MXNet")
.optTranslatorFactory(new DeepARTranslatorFactory())
.optArgument("prediction_length", predictionLength)
.optArgument("freq", "W")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ai.djl.basicdataset.tabular.TabularDataset;
import ai.djl.basicdataset.tabular.TabularResults;
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
Expand Down Expand Up @@ -55,7 +54,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
// Construct a tabNet instance
Block tabNet = TabNet.builder().setInputDim(5).setOutDim(1).build();

try (Model model = Model.newInstance("tabNet")) {
try (Model model = Model.newInstance("tabNet", arguments.getEngine())) {
model.setBlock(tabNet);

// get the training and validation dataset
Expand Down Expand Up @@ -103,13 +102,12 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
});

return new DefaultTrainingConfig(new TabNetRegressionLoss())
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
}

private static TabularDataset getDataset(Arguments arguments)
throws IOException, TranslateException {
private static TabularDataset getDataset(Arguments arguments) throws IOException {
AirfoilRandomAccess.Builder airfoilBuilder = AirfoilRandomAccess.builder();

// only train dataset is available, so we get train dataset and split them
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package ai.djl.examples.training;

import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.BertCodeDataset;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -59,7 +58,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
dataset.prepare();

// Create model & trainer
try (Model model = createBertPretrainingModel(dataset.getVocabularySize())) {
try (Model model = createBertPretrainingModel(arguments, dataset.getVocabularySize())) {

TrainingConfig config = createTrainingConfig(arguments);
try (Trainer trainer = model.newTrainer(config)) {
Expand All @@ -74,15 +73,15 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
}
}

private static Model createBertPretrainingModel(long vocabularySize) {
private static Model createBertPretrainingModel(Arguments arguments, long vocabularySize) {
Block block =
new BertPretrainingBlock(
BertBlock.builder()
.micro()
.setTokenDictionarySize(Math.toIntExact(vocabularySize)));
block.setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT);

Model model = Model.newInstance("Bert Pretraining");
Model model = Model.newInstance("Bert Pretraining", arguments.getEngine());
model.setBlock(block);
return model;
}
Expand All @@ -108,7 +107,7 @@ private static TrainingConfig createTrainingConfig(BertArguments arguments) {
.build();
return new DefaultTrainingConfig(new BertPretrainingLoss())
.optOptimizer(optimizer)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addTrainingListeners(Defaults.logging());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Model;
import ai.djl.basicdataset.nlp.GoEmotions;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.BertGoemotionsDataset;
import ai.djl.modality.nlp.embedding.EmbeddingException;
Expand Down Expand Up @@ -67,7 +66,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
dataset.prepare();

// Create model & trainer
try (Model model = createBertPretrainingModel(dataset.getVocabularySize())) {
try (Model model = createBertPretrainingModel(arguments, dataset.getVocabularySize())) {
TrainingConfig config = createTrainingConfig(arguments);
try (Trainer trainer = model.newTrainer(config)) {
// Initialize training
Expand Down Expand Up @@ -105,19 +104,19 @@ private static TrainingConfig createTrainingConfig(
.build();
return new DefaultTrainingConfig(new BertPretrainingLoss())
.optOptimizer(optimizer)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addTrainingListeners(TrainingListener.Defaults.logging());
}

private static Model createBertPretrainingModel(long vocabularySize) {
private static Model createBertPretrainingModel(Arguments arguments, long vocabularySize) {
Block block =
new BertPretrainingBlock(
BertBlock.builder()
.micro()
.setTokenDictionarySize(Math.toIntExact(vocabularySize)));
block.setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT);

Model model = Model.newInstance("Bert Pretraining");
Model model = Model.newInstance("Bert Pretraining", arguments.getEngine());
model.setBlock(block);
return model;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.CaptchaDataset;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
Expand Down Expand Up @@ -63,7 +62,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
return null;
}

try (Model model = Model.newInstance("captcha")) {
try (Model model = Model.newInstance("captcha", arguments.getEngine())) {
model.setBlock(getBlock());

// get training and validation dataset
Expand Down Expand Up @@ -107,7 +106,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {

DefaultTrainingConfig config =
new DefaultTrainingConfig(loss)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addEvaluators(loss.getComponents())
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
Expand Down
Loading

0 comments on commit 853c252

Please sign in to comment.