The goal of this repository is to use the most optimized techniques to train an LLM from scratch and to be as fast as possible at inference time.
I call this LLM anemone, but I'm not satisfied with the name. Why not MoM, for Mixture of Mixtures?
Using a virtual environment is recommended.
pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
- Add a 1.58-bit linear layer for the fastest inference, at some cost in quality (code from 1.58 bits); a sketch of the idea follows this list
- Add GaLore for training
- Use Jamba as the base architecture (code from jamba)
- Use Mixture of Depths (code from github)
- Use Mixture of Attention Heads (code from JetMoE)
- Add a script to train an LLM from scratch
- Add an evaluation script
- Use Multi-Head Mixture of Experts (MH-MoE)
- Fix the 1.58-bit linear layer
- Fix inference for MoD
- Use a filtered dataset, such as the one used for rho
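As a rough illustration of the 1.58-bit linear layer mentioned in the list above, here is a minimal sketch in the spirit of BitNet b1.58: ternary weights with a per-tensor scale, 8-bit per-token activations, and a straight-through estimator so the layer stays trainable. The class name `BitLinear158` and the quantization details are assumptions for illustration, not this repo's exact layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear158(nn.Module):
    """Minimal sketch of a 1.58-bit (ternary) linear layer, BitNet-b1.58 style.
    Hypothetical illustration, not this repo's exact implementation."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        w, eps = self.weight, 1e-5
        # Ternary weight quantization {-1, 0, +1} with a per-tensor absmean scale.
        scale_w = w.abs().mean().clamp(min=eps)
        w_q = (w / scale_w).round().clamp(-1, 1) * scale_w
        # 8-bit per-token activation quantization (absmax).
        scale_x = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
        x_q = (x / scale_x).round().clamp(-128, 127) * scale_x
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass.
        w_q = w + (w_q - w).detach()
        x_q = x + (x_q - x).detach()
        return F.linear(x_q, w_q, self.bias)
```

A layer like this can stand in for the regular linear layers everywhere except, in the mixed-precision variants described below, the attention projections.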
To test the first model, which has the 1.58-bit linear layers, the Jamba base architecture, and MoAH, you can clone this repo at this commit:
and run the following command:
python infer.py
To test the second model, which has the 1.58-bit linear layers, the Jamba base architecture, MoAH, and MoD, you can clone this repo at this commit.
You can start the training process by running the following command:
python train.py
and compare the results with the first model.
You can also run the following command to test the inference:
python infer.py
This model is a mixture of mixtures (MoE, MoD, MoAH) with the Jamba base architecture.
This model doesn't contain any 1.58-bit linear layers.
The difference from the previous model is the use of a softmax function to weight the tokens for MoD; this breaks causality, which may be why the model outputs nonsensical text.
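For intuition, here is a tiny, assumed sketch (not this repo's code) contrasting a per-token router with a softmax taken over the sequence for MoD: with the softmax, a token's weight depends on the scores of every token in the sequence, including future ones, which is exactly the causality problem described above.

```python
import torch

def mod_weights_sigmoid(scores):
    # scores: (batch, seq_len) router logits, one per token.
    # Each token's weight depends only on its own score -> causal.
    return torch.sigmoid(scores)

def mod_weights_softmax(scores):
    # Softmax over the sequence dimension: every token's weight is normalized
    # by the scores of all tokens, including future ones -> not causal.
    return torch.softmax(scores, dim=-1)

scores = torch.randn(1, 8)
scores2 = scores.clone()
scores2[0, 4:] += 10.0                         # perturb only future tokens

print(mod_weights_sigmoid(scores)[0, :4])
print(mod_weights_sigmoid(scores2)[0, :4])     # identical: past weights unaffected
print(mod_weights_softmax(scores)[0, :4])
print(mod_weights_softmax(scores2)[0, :4])     # earlier tokens' weights change
```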
You can also run the following command for this commit to test the inference:
python infer.py
This model is a mixture of mixtures (MoE, MoD, MoAH) with the Jamba base architecture.
All Mamba, router, MoE, and MLP linear layers are 1.58-bit. The linear layers in the attention mechanism are not.
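One way to build this kind of split is to walk the model and swap every linear layer outside the attention blocks for a 1.58-bit one. The sketch below reuses the hypothetical `BitLinear158` from earlier and matches modules by name, which is only an assumption about the module layout, not the way this repo actually constructs its variants.

```python
import torch.nn as nn

def quantize_except_attention(module, prefix="", attn_keywords=("attn", "attention")):
    """Recursively replace nn.Linear layers with BitLinear158 (sketch above),
    skipping any submodule whose name suggests it belongs to attention.
    Sketch only: the name matching is an assumption, not this repo's mechanism."""
    for child_name, child in list(module.named_children()):
        full_name = f"{prefix}.{child_name}" if prefix else child_name
        if any(k in full_name.lower() for k in attn_keywords):
            continue  # leave attention projections in bf16
        if isinstance(child, nn.Linear):
            bit = BitLinear158(child.in_features, child.out_features,
                               bias=child.bias is not None)
            bit.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                bit.bias.data.copy_(child.bias.data)
            setattr(module, child_name, bit)
        else:
            quantize_except_attention(child, full_name, attn_keywords)
    return module
```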
You can also run the following command to test the inference, replacing MoMv3 with MoMv3-mixed-precision in the file:
python infer.py --prompt "This is the story of"
To run the full 1.58-bit model, you can run the following command:
python infer.py --prompt "This is the story of" --model "MoMv3-1.58bits"
To run the model with Mamba and attention in bf16 and the rest in 1.58 bits, you can run the following command:
python infer.py --prompt "This is the story of" --model "MoMv3-M-A-mixed-precision"
To run the full bf16 model, you can run the following command:
python infer.py --prompt "This is the story of" --model "MoMv3-bf16"
This model is a mixture of mixtures (MoE, MoD, MoAH) with the Jamba base architecture.
All Mamba, router, MoE, and MLP linear layers are 1.58-bit; the linear layers in the attention mechanism are in bf16 precision.
Of the total parameters, 1.7% are in bf16 and the rest are in 1.58 bits.
As a first estimate, about 87M of the 1B total parameters are active.
Each MLP has 12.4M parameters. A token can pass through 7 plain MLP layers and 7 MoE layers with 2 experts each (2*7 expert MLPs), i.e. 21 MLP passes in total, which gives 12.4M * 21 = 261.1M MLP parameters. Adding the Mamba and attention parameters, which are near 107M, gives 368.1M, and since only about 1/4 of the tokens pass through each block, this yields the ~87M first estimate above.
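To make the estimate easier to follow, here is the same back-of-the-envelope calculation as a tiny script. The inputs are the rounded figures quoted above, so the result is only a rough first estimate.

```python
# Back-of-the-envelope active-parameter estimate, using the rounded figures above.
mlp_params = 12.4e6            # parameters per MLP
mlp_passes = 7 + 2 * 7         # 7 plain MLP layers + 2 experts in each of 7 MoE layers
mamba_attn_params = 107e6      # Mamba + attention parameters
capacity = 0.25                # MoD: only ~1/4 of tokens pass through a block

dense_path = mlp_params * mlp_passes + mamba_attn_params
active = dense_path * capacity
print(f"per-token path ≈ {dense_path / 1e6:.1f}M, active ≈ {active / 1e6:.0f}M")
# With these rounded inputs this prints roughly 92M active parameters,
# in the same ballpark as the ~87M first estimate quoted above.
```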
To test the inference, you can run the following command:
python infer.py --prompt "This is the story of" --model "MoMv4-1.58bits"
and
python infer.py --prompt "This is the story of" --model "MoMv4-bf16"
This model is a mixture of mixtures (MoE, MoD, MoAH) with the Jamba base architecture.
All Mamba, router, MoE, and MLP linear layers are in bf16 precision.
To test the inference, you can run the following command:
python infer.py --prompt "This is the story of" --model "MoMv5-bf16"
and
python eval.py --model "MoMv5-bf16" --max_seq_length 512
perplexity: 15.02
To evaluate the model, you can run the following command:
python eval.py --model "MoMv4-1.58bits" --max_seq_length 512
which has a loss of 2.62 and a perplexity of 13.77.
You can also run the evaluation for the full bf16 model:
python eval.py --model "MoMv4-bf16" --max_seq_length 512
which has a loss of 2.53 and a perplexity of 12.59.
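As a quick sanity check, the reported losses and perplexities are consistent with the usual relation perplexity = exp(loss):

```python
import math

# perplexity = exp(mean cross-entropy loss), using the figures reported above
print(math.exp(2.62))  # ~13.74, close to the reported 13.77 (MoMv4-1.58bits)
print(math.exp(2.53))  # ~12.55, close to the reported 12.59 (MoMv4-bf16)
```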
The bf16 version is a bit better.
We can see (here) that the baseline (MoMv3-bf16) has a loss curve similar to both the variant with attention and Mamba in bf16 and the rest in 1.58 bits (MoMv3-M-A-mixed-precision) and the variant with only attention in bf16 and the rest in 1.58 bits (MoMv3-mixed-precision).
Furthermore, training is faster when attention is not in 1.58 bits, and it uses less VRAM too.
To train a model with a long context and a lot of parameters while keeping inference fast and memory-light, I found that using the Jamba architecture with all linear layers in 1.58 bits, except for the attention mechanism's layers, can be a good strategy. With only 1.7% of the parameters in bf16, the model can fit on a cheap GPU during inference. Moreover, by using all the mixtures (MoAH, MoE, and MoD), you can train the model faster with only a few active parameters.
Contributions are welcome.
Please open a pull request with the proposed changes.
Apache License 2.0