Bumblebee is a small vision transformer, implemented in Rust with Candle, that runs locally in the browser via WebAssembly (WASM).
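A vision transformer starts by cutting the image into non-overlapping square patches and flattening each one into a vector before projecting it to a token embedding. The sketch below shows that patch-extraction step in plain std-only Rust. It does not use Candle. The function name `patchify`, the row-major HWC memory layout, and the patch size in the usage note are assumptions for illustration, not Bumblebee's actual code.

```rust
/// Hypothetical helper: split an HWC image (row-major, `channels` values per
/// pixel) into non-overlapping `patch` x `patch` tiles. Each tile is flattened
/// row by row; tiles are ordered left-to-right, then top-to-bottom.
fn patchify(
    img: &[f32],
    height: usize,
    width: usize,
    channels: usize,
    patch: usize,
) -> Vec<Vec<f32>> {
    assert_eq!(img.len(), height * width * channels, "image buffer has wrong length");
    assert!(
        patch > 0 && height % patch == 0 && width % patch == 0,
        "image must divide evenly into patches"
    );
    let mut patches = Vec::with_capacity((height / patch) * (width / patch));
    for py in (0..height).step_by(patch) {
        for px in (0..width).step_by(patch) {
            let mut tile = Vec::with_capacity(patch * patch * channels);
            for y in py..py + patch {
                // One tile row is `patch` consecutive pixels, i.e. a contiguous slice.
                let start = (y * width + px) * channels;
                tile.extend_from_slice(&img[start..start + patch * channels]);
            }
            patches.push(tile);
        }
    }
    patches
}
```

For a 32x32 RGB image (the CIFAR-10 resolution) and a hypothetical patch size of 4, `patchify(&img, 32, 32, 3, 4)` returns 64 patches of 48 values each; those 64 vectors would then become the transformer's input tokens after a linear projection.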
Bumblebee achieves 56.11% accuracy on the CIFAR-10 dataset. Training took 40 epochs on an RTX 4090 GPU.
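Assuming the reported figure is top-1 classification accuracy, it is computed by taking the highest-scoring class for each image and counting matches against the labels. The std-only Rust sketch below shows that calculation; `argmax` and `accuracy_percent` are hypothetical names, not part of Bumblebee or Candle.

```rust
/// Index of the largest logit; the first index wins on ties.
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best
}

/// Top-1 accuracy in percent: the share of samples whose highest-scoring
/// class matches the ground-truth label.
fn accuracy_percent(logits: &[Vec<f32>], labels: &[usize]) -> f64 {
    assert_eq!(logits.len(), labels.len(), "one label per row of logits");
    assert!(!labels.is_empty(), "need at least one sample");
    let correct = logits
        .iter()
        .zip(labels.iter())
        .filter(|&(row, &label)| argmax(row) == label)
        .count();
    100.0 * correct as f64 / labels.len() as f64
}
```

As a worked example of the arithmetic, 56.11% over 10,000 images (the size of the CIFAR-10 test set) corresponds to 5,611 correct predictions.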
To train the model:
cargo run --bin train
To compile to WASM and serve the site:
./build.sh
python -m http.server -d site/dist
Note: after training finishes, copy the weights.safetensors file into site/public.
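The copy step can also be scripted. This std-only Rust sketch copies the weights into a destination folder such as site/public, first checking the basic safetensors layout: the file begins with an 8-byte little-endian u64 giving the length of the JSON header that follows. The helper names are hypothetical, and the check only rejects truncated or obviously non-safetensors files; it does not parse the JSON or validate the tensors.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Returns the JSON header length stored in the first 8 bytes of a
/// safetensors buffer (little-endian u64), or None if the buffer is too short
/// or the declared header would run past the end of the data.
fn safetensors_header_len(bytes: &[u8]) -> Option<u64> {
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(bytes.get(..8)?);
    let len = u64::from_le_bytes(prefix);
    // Safe subtraction: get(..8) succeeded, so the buffer holds at least 8 bytes.
    let available = (bytes.len() - 8) as u64;
    if len > available {
        return None;
    }
    Some(len)
}

/// Hypothetical helper: copies trained weights into `dest_dir` as
/// weights.safetensors, refusing files that fail the header check.
/// Reads the whole file, which is acceptable for a small model.
/// Returns the number of bytes copied.
fn copy_weights(src: &Path, dest_dir: &Path) -> io::Result<u64> {
    let bytes = fs::read(src)?;
    if safetensors_header_len(&bytes).is_none() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "source does not look like a safetensors file",
        ));
    }
    fs::create_dir_all(dest_dir)?;
    fs::copy(src, dest_dir.join("weights.safetensors"))
}
```

From the repository root this might be called as `copy_weights(Path::new("weights.safetensors"), Path::new("site/public"))`; the source path is an assumption about where training writes the file.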
TODO:
- Train on a better GPU
- Add wasm-pack-plugin to the webpack config and move the crate into site
- Add dropout to training
- Refactor the code and move model/training parameters into a config struct
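For the dropout and config items, here is a minimal std-only Rust sketch of inverted dropout driven by a small config struct. Everything in it is hypothetical: the `TrainConfig` fields and defaults (other than the 40 epochs stated above), the SplitMix64-seeded xorshift RNG used to avoid external crates, and the function names. In a Candle model you would more likely use the dropout provided by the candle_nn crate (check its documentation); this only illustrates the technique.

```rust
/// Hypothetical training config; only `epochs` comes from this README.
#[derive(Debug, Clone)]
struct TrainConfig {
    epochs: usize,
    dropout_p: f32,
    seed: u64,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self { epochs: 40, dropout_p: 0.1, seed: 42 }
    }
}

/// Tiny xorshift64 PRNG so the sketch needs no external crates.
struct XorShift64(u64);

impl XorShift64 {
    fn new(seed: u64) -> Self {
        // One SplitMix64 step scrambles small seeds into a well-mixed state;
        // `| 1` keeps the state non-zero, which xorshift requires.
        let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        Self(z | 1)
    }

    /// Uniform value in [0, 1) built from the top 24 bits of the state.
    fn next_f32(&mut self) -> f32 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        (x >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Inverted dropout: during training, zero each activation with probability
/// `p` and scale survivors by 1 / (1 - p) so the expected value is unchanged.
/// At inference (`train == false`) the input passes through untouched.
fn dropout(xs: &[f32], p: f32, train: bool, rng: &mut XorShift64) -> Vec<f32> {
    assert!((0.0..1.0).contains(&p), "dropout probability must be in [0, 1)");
    if !train || p == 0.0 {
        return xs.to_vec();
    }
    let scale = 1.0 / (1.0 - p);
    xs.iter()
        .map(|&x| if rng.next_f32() < p { 0.0 } else { x * scale })
        .collect()
}
```

Scaling during training rather than at inference keeps the browser-side forward pass unchanged: the WASM build would simply call the model with `train == false`.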