Unable to use .pb in tensorflow's java api #1244

Open
zaryabRiasat opened this issue Dec 1, 2023 · 0 comments


I'm trying to use this pre-trained model in Java. I'm using IntelliJ IDEA; I've added TensorFlow as a library dependency and added OpenCV via Project Structure.

libraryDependencies += "org.tensorflow" % "tensorflow" % "1.15.0"

I've downloaded the VGGFace2 pre-trained model and am trying to use its .pb file to compute face embeddings.

Code:

import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.imgcodecs.Imgcodecs;
import org.tensorflow.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.file.Paths;
import java.nio.file.Files;
import java.io.IOException;

public class DirectTensorflowTest {
    public static void main(String[] args) {
        // Load OpenCV library
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        // Path to the FaceNet model
        String modelDir = "/home/zaryab/Downloads/20170512-110547";

        try {
            // Read the FaceNet model graph
            byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "20170512-110547.pb"));

            // Import the graph definition into TensorFlow Graph
            Graph g = new Graph();
            g.importGraphDef(graphDef);

            // Create a TensorFlow session with the imported graph
            try (Session s = new Session(g)) {
                // Load an image using OpenCV (replace this with your image loading logic)
                String imagePath = "/home/zaryab/Desktop/IMG_20231124_171435.jpg";
                Mat openCVMat = Imgcodecs.imread(imagePath);

                // Convert OpenCV Mat to float array
                float[] floatArray = convertMatToFloatArray(openCVMat);

                // Wrap the float array in a FloatBuffer for the TensorFlow Tensor
                FloatBuffer fb = FloatBuffer.wrap(floatArray);
                Tensor<Float> imageF = Tensor.create(new long[]{1, openCVMat.rows(), openCVMat.cols(), 1}, fb);

                Tensor<Boolean> falseTensor = Tensors.create(false);

                // Run the session to get embeddings
                Tensor<Float> result = s.runner()
                        .feed("input", imageF)
                        .feed("phase_train", falseTensor)
                        .fetch("embeddings")
                        .run()
                        .get(0)
                        .expect(Float.class);

                // Access the embeddings
                float[][] embeddings = new float[1][(int) result.shape()[1]]; // Assuming shape[1] gives the embedding size
                result.copyTo(embeddings);

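                // Print the values; printing the array itself only shows a reference like [[F@1b6d3586
                System.out.println(java.util.Arrays.toString(embeddings[0]));

                result.close();
                imageF.close();
                falseTensor.close();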
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Function to convert OpenCV Mat to float array
    public static float[] convertMatToFloatArray(Mat mat) {
        int rows = mat.rows();
        int cols = mat.cols();

        float[] floatArray = new float[rows * cols];

        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // Uses only channel 0 (blue, since imread loads BGR) and scales it to [0, 1]
                floatArray[i * cols + j] = (float) (mat.get(i, j)[0] / 255.0);
            }
        }

        return floatArray;
    }

    // Function to read all bytes from a file
    public static byte[] readAllBytesOrExit(java.nio.file.Path path) throws IOException {
        return Files.readAllBytes(path);
    }
}

I took some help from https://github.com/davidsandberg/facenet/issues/659, but I'm still getting this error:

2023-12-01 11:19:51.689467: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at conv_ops.cc:491 : Invalid argument: input depth must be evenly divisible by filter depth: 1 vs 3
Exception in thread "main" java.lang.IllegalArgumentException: input depth must be evenly divisible by filter depth: 1 vs 3
	 [[{{node InceptionResnetV1/Conv2d_1a_3x3/convolution}}]]
	at org.tensorflow.Session.run(Native Method)
	at org.tensorflow.Session.access$100(Session.java:48)
	at org.tensorflow.Session$Runner.runHelper(Session.java:326)
	at org.tensorflow.Session$Runner.run(Session.java:276)
	at DirectTensorflowTest.main(DirectTensorflowTest.java:111)
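The error points at the first convolution, `InceptionResnetV1/Conv2d_1a_3x3`, whose filters have a depth of 3: the network expects a 3-channel (RGB) image, but the code above builds a tensor of shape `{1, rows, cols, 1}` from channel 0 only. Below is a minimal sketch of the conversion, assuming the Mat is an 8-bit, 3-channel BGR image (what `Imgcodecs.imread` returns by default) whose bytes have already been copied into a `byte[]`, for example with `mat.get(0, 0, bytes)`. It swaps BGR to RGB, keeps the original code's scaling to [0, 1], and builds the matching `{1, rows, cols, 3}` shape. The class and method names are hypothetical. Facenet's `compare.py` also crops and resizes faces (160×160 by default, which in Java could be done with `Imgproc.resize`) before this step; that part is not shown.

```java
import java.nio.FloatBuffer;
import java.util.Arrays;

public class BgrToRgbInput {

    /**
     * Converts interleaved 8-bit BGR pixels (OpenCV's default CV_8UC3 layout)
     * into RGB floats in NHWC order, scaled to [0, 1].
     * Pixel (row, col), channel c lands at index (row * cols + col) * 3 + c.
     */
    public static float[] bgrBytesToRgbFloats(byte[] bgr, int rows, int cols) {
        int expected = rows * cols * 3;
        if (bgr.length != expected) {
            throw new IllegalArgumentException(
                    "expected " + expected + " bytes, got " + bgr.length);
        }
        float[] rgb = new float[expected];
        for (int p = 0; p < rows * cols; p++) {
            int base = p * 3;
            // Java bytes are signed, so mask with 0xFF to recover 0..255
            rgb[base] = (bgr[base + 2] & 0xFF) / 255.0f;     // R (stored last in BGR)
            rgb[base + 1] = (bgr[base + 1] & 0xFF) / 255.0f; // G
            rgb[base + 2] = (bgr[base] & 0xFF) / 255.0f;     // B (stored first in BGR)
        }
        return rgb;
    }

    /** NHWC shape for a batch of one image with three colour channels. */
    public static long[] nhwcShape(int rows, int cols) {
        return new long[]{1, rows, cols, 3};
    }

    public static void main(String[] args) {
        // Hypothetical 1x2 image: a pure-blue pixel then a pure-red pixel, stored as BGR
        byte[] bgr = {(byte) 255, 0, 0, 0, 0, (byte) 255};
        float[] rgb = bgrBytesToRgbFloats(bgr, 1, 2);
        System.out.println(Arrays.toString(rgb));             // [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
        System.out.println(Arrays.toString(nhwcShape(1, 2))); // [1, 1, 2, 3]
        // With TensorFlow 1.x, this buffer and shape would go to Tensor.create(shape, buffer)
        FloatBuffer buffer = FloatBuffer.wrap(rgb);
        System.out.println(buffer.remaining());               // 6
    }
}
```

In the issue's code, the resulting array would take the place of `convertMatToFloatArray`'s output, and `Tensor.create(nhwcShape(rows, cols), FloatBuffer.wrap(rgb))` would replace the single-channel `Tensor.create` call, so the input depth becomes 3.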

Can someone guide me on how to use this model successfully in Java with TensorFlow?
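Fixing the channel count may not be enough for meaningful embeddings: facenet's Python examples (such as `compare.py`) standardize each face crop with a `prewhiten` helper, y = (x - mean) / max(std, 1/sqrt(n)), before feeding the `input` node. Below is a sketch of that helper in plain Java, assuming this model expects the same per-image standardization; it is worth checking the facenet source for the model version you downloaded, since later releases changed the preprocessing. The `Prewhiten` class name is hypothetical.

```java
import java.util.Arrays;

public class Prewhiten {

    /**
     * Per-image standardization modelled on facenet's Python prewhiten():
     * y = (x - mean) / max(std, 1 / sqrt(n)), where n is the number of values
     * and std is the population standard deviation (as numpy's np.std computes).
     */
    public static float[] prewhiten(float[] x) {
        int n = x.length;
        if (n == 0) {
            throw new IllegalArgumentException("empty input");
        }
        double mean = 0.0;
        for (float v : x) {
            mean += v;
        }
        mean /= n;
        double variance = 0.0;
        for (float v : x) {
            variance += (v - mean) * (v - mean);
        }
        double std = Math.sqrt(variance / n);
        // Lower bound avoids dividing by ~0 on a flat (constant) image
        double stdAdj = Math.max(std, 1.0 / Math.sqrt(n));
        float[] y = new float[n];
        for (int i = 0; i < n; i++) {
            y[i] = (float) ((x[i] - mean) / stdAdj);
        }
        return y;
    }

    public static void main(String[] args) {
        float[] pixels = {0f, 255f, 0f, 255f};
        System.out.println(Arrays.toString(prewhiten(pixels))); // [-1.0, 1.0, -1.0, 1.0]
    }
}
```

This would be applied to the 3-channel float array before it is wrapped in a `FloatBuffer` and fed to `input`. Because it subtracts the mean and divides by the standard deviation, the result is largely the same whether or not the pixels were first scaled to [0, 1].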
