Tensorflow js models? #169

Closed
Mattk70 opened this issue Oct 1, 2023 · 11 comments

@Mattk70
Contributor

Mattk70 commented Oct 1, 2023

Hi Stefan, et al.

You've done great work here, it's really inspiring! I wanted to ask if there was any prospect of releasing the latest models in a TensorFlow.js format, so they could be run in a web browser? I've seen this repo: https://github.com/kahst/BirdNET-Electron but it was last updated about 4 years ago.

@Josef-Haupt
Collaborator

We don't have a workflow to convert models right now, but you can try for yourself. There is a guide here.

@Mattk70
Contributor Author

Mattk70 commented Oct 2, 2023

Thanks for getting back to me @Josef-Haupt,

I did try the conversion, this is the command I used:

tensorflowjs_converter --input_format=tf_saved_model --output_node_names='CLASS_DENSE_LAYER' --saved_model_tags=serve --signature_name='basic' ./checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model/ ./js_model

On execution, I get the expected whinging about gradients, but conversion fails with:

ValueError: Unsupported Ops in the model after optimization _FusedBatchNormEx

I wondered if you had implemented custom operations in the model? That would be one reason I would get this error.

@kahst
Owner

kahst commented Dec 6, 2023

It's been a while and I just rediscovered this issue: We do have a TFJS version of BirdNET ready to use, including an example of how to run it. You can find the model and the code here: https://github.com/kahst/BirdNET-Analyzer/tree/main/checkpoints/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_TFJS

I would love to develop something like birdnet.js or BirdNET as an npm package, but we lack the skills to do so. If you're interested in attempting that, let us know :)

@Mattk70
Contributor Author

Mattk70 commented Jan 22, 2024

@kahst Sorry - I didn't see your reply until this morning. Thanks for sharing, that's awesome!

I've had a little play around, it's looking really great 😄

One suggestion: if you make the following small changes to the call method of the MelSpecLayerSimple class in your example, it will support batch predictions, which gives a decent performance boost (up to c. 40% faster inference):

class MelSpecLayerSimple extends tf.layers.Layer {
    constructor(config) {
        super(config);

        // Initialize parameters
        this.sampleRate = config.sampleRate;
        this.specShape = config.specShape;
        this.frameStep = config.frameStep;
        this.frameLength = config.frameLength;
        this.fmin = config.fmin;
        this.fmax = config.fmax;
        this.melFilterbank = tf.tensor2d(config.melFilterbank);
    }

    build(inputShape) {
        // Initialize trainable weights, for example:
        this.magScale = this.addWeight(
            'magnitude_scaling',
            [],
            'float32',
            tf.initializers.constant({ value: 1.23 })
        );

        super.build(inputShape);
    }

    // Compute the output shape of the layer
    computeOutputShape(inputShape) {
        return [inputShape[0], this.specShape[0], this.specShape[1], 1];
    }

    // Define the layer's forward pass
    call(inputs) {
        return tf.tidy(() => {
            // inputs is a tensor representing the input data
            inputs = inputs[0];
            const inputList = tf.split(inputs, inputs.shape[0])
            const specBatch = inputList.map(input =>{
                input = input.squeeze();
                // Normalize values between -1 and 1
                input = tf.sub(input, tf.min(input, -1, true));
                input = tf.div(input, tf.max(input, -1, true).add(0.000001));
                input = tf.sub(input, 0.5);
                input = tf.mul(input, 2.0);

                // Perform STFT
                let spec = tf.signal.stft(
                    input,
                    this.frameLength,
                    this.frameStep,
                    this.frameLength,
                    tf.signal.hannWindow,
                );

                // Cast from complex to float
                spec = tf.cast(spec, 'float32');

                // Apply mel filter bank
                spec = tf.matMul(spec, this.melFilterbank);

                // Convert to power spectrogram
                spec = spec.pow(2.0);

                // Apply nonlinearity
                spec = spec.pow(tf.div(1.0, tf.add(1.0, tf.exp(this.magScale.read()))));

                // Flip the spectrogram
                spec = tf.reverse(spec, -1);

                // Swap axes to fit input shape
                spec = tf.transpose(spec)

                // Adding the channel dimension
                spec = spec.expandDims(-1);

                // Adding batch dimension
                //spec = spec.expandDims(0);

                return spec;
            })
            return tf.stack(specBatch)
        });
    }

    // Optionally, include the `className` method to provide a machine-readable name for the layer
    static get className() {
        return 'MelSpecLayerSimple';
    }
}

// Register the custom layer with TensorFlow.js
tf.serialization.registerClass(MelSpecLayerSimple);

I can create a PR if you like, but I ran out of time today! Meanwhile, I have an Electron app which now has your BirdNET model as a backend option, taking advantage of the metadata model's 'species location' filtering. Happy days!
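
For readers following the batched `call` method above: each clip in the batch is normalized on its own before the STFT. The same arithmetic can be sketched in plain JavaScript, without TensorFlow.js, to show why that matters. The function names `normalizeClip` and `normalizeBatch` and the sample values are illustrative assumptions, not part of the BirdNET example code.

```javascript
// Min-max scale one clip to roughly [-1, 1], mirroring the tensor ops in the
// call() method above: subtract the min, divide by (max + epsilon), shift, scale.
function normalizeClip(samples) {
    let min = Infinity;
    for (const s of samples) if (s < min) min = s;
    let max = -Infinity;
    for (const s of samples) if (s - min > max) max = s - min;
    const out = new Float32Array(samples.length);
    for (let i = 0; i < samples.length; i++) {
        out[i] = ((samples[i] - min) / (max + 0.000001) - 0.5) * 2.0;
    }
    return out;
}

// Normalize every clip in a batch independently of the others.
function normalizeBatch(clips) {
    return clips.map(normalizeClip);
}

// Illustrative data: a very quiet clip next to a full-scale one.
const quiet = Float32Array.from([0.0, 0.005, 0.01]);
const loud = Float32Array.from([-1.0, 0.0, 1.0]);
const [q, l] = normalizeBatch([quiet, loud]);
console.log(Array.from(q), Array.from(l));
```

Because the minimum and maximum are taken per clip, the quiet clip ends up spanning approximately the same [-1, 1] range as the loud one, which is what a batch of one would have produced.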

@kahst
Owner

kahst commented Jan 22, 2024

Awesome! Yes, please create a PR. Also, let us know more about your Electron app, we're always looking for cool projects to feature in our showcase. Thanks!

@Mattk70
Contributor Author

Mattk70 commented Jan 25, 2024

@kahst @Josef-Haupt - hey guys, sorry to bother you again. I have a question about your TFJS conversion. A comparison shows that the results from the TFJS model are similar to, but far from identical to, those given by BirdNET-Analyser-GUI.exe. The TFJS model output for the example soundscape.wav file from this repo is here:
TFJS_detections.csv
And for convenience, the Python version output:
BirdNET-Analyser-GUI.csv

For the avoidance of doubt, I have used these settings in the TFJS analysis:

  • no location filters
  • confidence 0.1
  • no sigmoid
    (the TFJS model has a trainable sigmoid activation head, so it already outputs probabilities. This architectural difference will explain some of the difference, and may even be the whole answer?)

And in the GUI

  • sigmoid 1 (default)
  • all species (i.e. no filter)
  • confidence 0.1 (default)

In both cases batchsize and threads settings make zero difference, as you would expect.
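
To illustrate why the sigmoid settings listed above matter when comparing the two outputs, here is a minimal sketch of a logistic function with a sensitivity multiplier. The function name, the clipping range of ±15 and the example values are assumptions for illustration; the exact form used by either implementation should be checked against its source.

```javascript
// Logistic function with a sensitivity multiplier applied to the logit.
// The clipping range and default sensitivity are assumptions.
function sigmoidWithSensitivity(logit, sensitivity = 1.0) {
    const clipped = Math.min(15, Math.max(-15, logit));
    return 1 / (1 + Math.exp(-sensitivity * clipped));
}

const logit = 2.0; // illustrative raw model output
const once = sigmoidWithSensitivity(logit);          // probability from a raw logit
const twice = sigmoidWithSensitivity(once);          // sigmoid applied again to a probability
const steeper = sigmoidWithSensitivity(logit, 1.5);  // same logit, higher sensitivity
console.log(once.toFixed(4), twice.toFixed(4), steeper.toFixed(4));
```

If one pipeline applies the sigmoid to logits and the other applies it to values that are already probabilities, or the two use different sensitivity values, the confidences will differ even though the underlying model is the same.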

I know there could be myriad reasons for this - there are many possibilities at my end. However, before I spend too much time comparing the outputs of the individual libraries involved, I wanted to see if you'd used any aggressive quantization in the TFJS conversion? Like UINT8 in some activation layers for example? Using Netron, I noticed some of the layers have 'QUANT' and others 'NO_QUANT' appended to their names, so I thought it would be a good thing to check.

It's worth saying, I'm not sure this is even a bug, it may be expected behaviour - both models make good predictions, but if I am to release an app which claims to offer the BirdNET 2.4 model for predictions, I would like a handle on why the results are this different.

Thanks again for your input, I really appreciate it! The app I've been developing was initially inspired by BirdNET-Electron. You may find some of the design elements familiar, Stefan, although it's come a long way since. You can read about it, and even download it, here: https://chirpity.mattkirkland.co.uk

(When I release a version with the new BirdNET model, the app will auto-update)

@kahst
Owner

kahst commented Jan 26, 2024

Hmmm, there shouldn't be a significant difference between models. We do test our models after conversion to make sure the scores align with the original Keras version - even after quantization. The TFJS model is not quantized and should yield scores that are very similar to the TFLite 32bit model. There might be something off with the GUI though. We'll investigate and keep you posted.

@Mattk70
Contributor Author

Mattk70 commented Jan 26, 2024

Hi @kahst , thanks again for getting back to me. I just ran the same analysis of the soundscape.wav file through the CLI and the results are identical to the GUI results posted above.

I think next steps would be for me to build/run a minimal implementation for the TFJS model such that it can accept the soundscape.wav, chunk it up into 3 second clips and output predictions. If the results are still different I'll share that code back with you, otherwise I'll know it's an issue at my end.
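
The chunking step of such a minimal implementation could look roughly like the sketch below. It assumes mono samples already at 48000 Hz, non-overlapping 3 second windows, and zero-padding of the final partial window; the function name `splitIntoChunks` and those choices are illustrative assumptions rather than the Analyzer's behaviour.

```javascript
// Split mono audio into fixed-length clips, zero-padding the last one.
// Non-overlapping windows and zero-padding are assumptions of this sketch.
function splitIntoChunks(samples, sampleRate = 48000, chunkSeconds = 3) {
    const chunkLength = sampleRate * chunkSeconds;
    const chunks = [];
    for (let start = 0; start < samples.length; start += chunkLength) {
        const chunk = new Float32Array(chunkLength); // starts zero-filled
        chunk.set(samples.subarray(start, start + chunkLength));
        chunks.push({ startSeconds: start / sampleRate, samples: chunk });
    }
    return chunks;
}

// Illustrative input: 7 seconds of constant signal.
const audio = new Float32Array(48000 * 7).fill(1);
const clips = splitIntoChunks(audio);
console.log(clips.map(c => [c.startSeconds, c.samples.length]));
```

Each clip then has the 144000 samples (3 s at 48000 Hz) that a fixed-length input of this kind requires, and the start times can be carried through to the results table.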

@tphakala

@Mattk70 I bet the difference comes from the differences in sigmoid sensitivity implementation you mentioned

@Mattk70
Contributor Author

Mattk70 commented Jan 26, 2024

@kahst @tphakala OK. I have done a validation of the TFJS model using the flask app and sample.wav

Analyzer.py output for this file:
Start (s),End (s),Scientific name,Common name,Confidence
0,3.0,Acanthis cabaret,Lesser Redpoll,0.9444
0,3.0,Acanthis flammea,Common Redpoll,0.7997
0,3.0,Acanthis hornemanni,Hoary Redpoll,0.0590

TFJS via the flask app in this repo* produces the following:
Results:
Acanthis cabaret_Lesser Redpoll: 0.9447994828224182
Acanthis flammea_Common Redpoll: 0.7999176383018494
Acanthis hornemanni_Hoary Redpoll: 0.05872755125164986

Whilst these are different they are minimally so, which is what I'd have expected.
Sadly, my implementation produces:

Start (s) | End (s) | Scientific name | Common name | Confidence
0 | 3 | Acanthis cabaret | Lesser Redpoll | 0.96
0 | 3 | Acanthis flammea | Common Redpoll | 0.84
0 | 3 | Acanthis hornemanni | Hoary Redpoll | 0.06

So, it's an issue with my code.
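
The comparison above can be made mechanical with a small check like the following. The scores are copied from the three outputs in this comment; the helper name `maxAbsDifference` and the tolerance of 0.001 are assumptions chosen for illustration.

```javascript
// Largest absolute score difference across the species in the reference set.
function maxAbsDifference(reference, candidate) {
    let worst = 0;
    for (const [species, score] of Object.entries(reference)) {
        const diff = Math.abs(score - candidate[species]);
        if (diff > worst) worst = diff;
    }
    return worst;
}

const analyzer = {
    'Acanthis cabaret': 0.9444,
    'Acanthis flammea': 0.7997,
    'Acanthis hornemanni': 0.0590,
};
const flaskTfjs = {
    'Acanthis cabaret': 0.9447994828224182,
    'Acanthis flammea': 0.7999176383018494,
    'Acanthis hornemanni': 0.05872755125164986,
};
const custom = {
    'Acanthis cabaret': 0.96,
    'Acanthis flammea': 0.84,
    'Acanthis hornemanni': 0.06,
};

const TOLERANCE = 0.001; // hypothetical threshold for "minimally different"
console.log(maxAbsDifference(analyzer, flaskTfjs) <= TOLERANCE); // true
console.log(maxAbsDifference(analyzer, custom) <= TOLERANCE);    // false
```

With these numbers the Flask/TFJS output differs from Analyzer.py by at most about 0.0004, while the custom implementation is off by about 0.04 on Common Redpoll.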

The Flask app has a couple of issues I needed to correct before I could get it to work.

  1. If you create a new AudioContext with no parameters, its default behaviour is to resample audio to 44100Hz - consequently, the resulting tensor is too short. I added {sampleRate: 48000}
  2. If you try to create the audioContext onload, it breaks default browser policies around autoplaying media, and errors out. I added a button to generate the required user interaction, which is disabled until the page has loaded
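
A minimal sketch of both fixes, assuming a 3 second clip and a model that expects 48000 Hz input. The `sampleRate` option is part of the standard `AudioContext` constructor; the helper names and the `analyze` button id are hypothetical, and the constructor is passed in as an argument only so the sketch can also run outside a browser.

```javascript
const MODEL_SAMPLE_RATE = 48000;

// Number of samples a clip of the given duration decodes to at a given rate.
function sampleCount(durationSeconds, sampleRate) {
    return Math.round(durationSeconds * sampleRate);
}

// Create the context with an explicit sample rate instead of the default.
// AudioContextCtor is injected (window.AudioContext in a browser).
function createContext(AudioContextCtor) {
    return new AudioContextCtor({ sampleRate: MODEL_SAMPLE_RATE });
}

// In a browser, create the context from a user gesture rather than on load.
// The element id 'analyze' is hypothetical:
//
// document.getElementById('analyze').addEventListener('click', () => {
//     const ctx = createContext(window.AudioContext);
//     // decode and analyse audio with ctx here
// });

console.log(sampleCount(3, 44100), sampleCount(3, MODEL_SAMPLE_RATE)); // 132300 144000
```

A 3 second clip decoded at 44100 Hz has 132300 samples, 11700 short of the 144000 that 48000 Hz gives, which matches the "tensor is too short" symptom in point 1.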

I'll do a PR with these changes. And then start looking into the dang problem with my code. At least, if there's an upside, I'd never have spotted this bug if I hadn't integrated BirdNET ;-)

@Mattk70
Contributor Author

Mattk70 commented Jan 26, 2024

Nailed it! My native model uses a 24000Hz sample rate, and there was one place in the audio pipeline I'd missed, which was still switching the audioContext to that rate when using BirdNET, before then upsampling the audioBuffer back to 48k. Now that's fixed, the predictions for soundscape.wav from the TFJS and Python models are almost identical.
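
A bug of this kind is easy to miss because the final buffer has the right sample rate and length. Audio that passes through a 24000 Hz stage keeps only content below 12000 Hz (half the sample rate), and upsampling back to 48000 Hz does not restore what was lost. A guard along these lines could flag it; the stage list and the function name `findLossyStages` are hypothetical.

```javascript
// Return every pipeline stage that runs below the model's sample rate,
// with the highest frequency that stage can still represent.
function findLossyStages(stages, modelSampleRate) {
    return stages
        .filter(stage => stage.sampleRate < modelSampleRate)
        .map(stage => ({ name: stage.name, maxFrequencyHz: stage.sampleRate / 2 }));
}

// Hypothetical description of an audio pipeline like the one described above.
const pipeline = [
    { name: 'decode', sampleRate: 24000 },
    { name: 'resample', sampleRate: 48000 },
    { name: 'model input', sampleRate: 48000 },
];

console.log(findLossyStages(pipeline, 48000));
```

An empty result means no stage runs below the model's rate; here the `decode` stage is reported with a 12000 Hz ceiling.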

Fixed comparison.csv
