Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax. EasyLM can scale LLM training up to hundreds of TPU/GPU accelerators by leveraging JAX's pjit functionality.
Building on top of Hugging Face's transformers and datasets, this repo provides an easy-to-use and easy-to-customize codebase for training large language models without the complexity of many other frameworks.
EasyLM is built with JAX/Flax. By leveraging JAX's pjit utility, EasyLM can train large models that don't fit on a single accelerator by sharding the model weights and training data across multiple accelerators. Currently, EasyLM supports multi-accelerator (TPU/GPU) training on a single host as well as multi-host training on Google Cloud TPU Pods.
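As a rough illustration of the sharding idea described above, the sketch below uses JAX's public sharding API (`Mesh`, `NamedSharding`, `PartitionSpec`) to split a batch across a data-parallel mesh axis and a weight matrix across a model-parallel axis. It is a minimal sketch, not EasyLM's own partitioning code: the mesh axis names `dp`/`mp`, the toy `forward` layer, and all shapes are hypothetical. In recent JAX releases, pjit's functionality is available through `jax.jit` with sharding annotations, which is what the sketch uses. It also runs on a single CPU device, where every mesh axis simply has size 1.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2D device mesh: "dp" (data parallel) x "mp" (model parallel).
# All visible devices go on the "dp" axis here, so "mp" has size 1.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "mp"))


def forward(w, x):
    """A single dense layer with tanh, standing in for a model forward pass."""
    return jnp.tanh(x @ w)


n_dp = mesh.shape["dp"]
x = jnp.ones((8 * n_dp, 16))   # global batch; size divisible by the "dp" axis
w = jnp.full((16, 32), 0.01)   # weight matrix; 32 columns split over "mp"

# Place the batch split along "dp" (rows) and the weights split along "mp" (columns).
x_sharded = jax.device_put(x, NamedSharding(mesh, P("dp", None)))
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "mp")))

# jax.jit picks up the input shardings from the arguments; out_shardings fixes the output layout.
forward_sharded = jax.jit(forward, out_shardings=NamedSharding(mesh, P("dp", "mp")))
y = forward_sharded(w_sharded, x_sharded)
print(y.shape, y.sharding)
```

Putting every device on `dp` gives pure data parallelism; reshaping the device array to, for example, `(n // 2, 2)` would instead split each weight matrix across pairs of devices, in which case the `mp`-sharded dimension must be divisible by 2.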
This fork adds support for RLHF methods such as Direct Preference Optimization (DPO) and Proximal Policy Optimization (PPO) for some models. It has been used to train models such as Tulu 2, with more coming soon.
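For readers new to these methods, the sketch below writes out the standard DPO loss and the PPO clipped surrogate objective in JAX. It is a minimal, hypothetical illustration of the published objectives, not this fork's implementation: the function names (`sequence_logps`, `dpo_loss`, `ppo_clip_loss`), the argument layouts, and the default `beta`/`clip_eps` values are assumptions. DPO rewards the policy for preferring the chosen response over the rejected one by more than a frozen reference model does; PPO limits how far the probability ratio between the new and old policy can move in a single update.

```python
import jax
import jax.numpy as jnp


def sequence_logps(logits, labels, mask):
    """Sum of per-token log-probabilities of `labels` under `logits`.

    logits: [batch, seq, vocab]; labels, mask: [batch, seq].
    `mask` is 1 for response tokens and 0 for prompt/padding tokens.
    """
    logps = jax.nn.log_softmax(logits, axis=-1)
    token_logps = jnp.take_along_axis(logps, labels[..., None], axis=-1)[..., 0]
    return jnp.sum(token_logps * mask, axis=-1)


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_logratio = policy_chosen - ref_chosen
    rejected_logratio = policy_rejected - ref_rejected
    return -jnp.mean(jax.nn.log_sigmoid(beta * (chosen_logratio - rejected_logratio)))


def ppo_clip_loss(new_logps, old_logps, advantages, clip_eps=0.2):
    """PPO clipped surrogate objective, negated so it can be minimized."""
    ratio = jnp.exp(new_logps - old_logps)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# Usage with random logits: when the policy equals the reference, the DPO loss is log(2).
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 5, 10))
labels = jnp.zeros((2, 5), dtype=jnp.int32)
mask = jnp.ones((2, 5))
lp = sequence_logps(logits, labels, mask)
print(dpo_loss(lp, lp, lp, lp))  # ≈ 0.6931
```

Because DPO only needs sequence log-probabilities from the policy and reference models, it avoids the separate reward model and rollout loop that PPO requires.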
The original EasyLM is no longer supported. This fork is maintained primarily for research use and does not come with standard maintenance rules and guidelines.
Currently, the following models are supported to some extent.
For the core model, currently only Llama, the directory `EasyLM/models/llama/` contains scripts such as `convert_hf_to_easylm.py` and `convert_easylm_to_hf.py` for easy integration with other libraries.
Models trained here can also be evaluated with AllenAI's Open Instruct repository via `scripts/submit_open_instruct_eval.py`.
The directory is organized as follows:
├── README.md <- The top-level README for researchers using this project
├── beaker_configs/ <- [AI2 only] example config and automatically generated experiment configs
├── conversion_scripts/ <- Scripts for creating .json datasets from HuggingFace format (see `create_preference_data.sh`)
├── docs/ <- New and existing documentation
| ├── ai2.md <- In-depth tutorial on how to use EasyLM on AI2 infrastructure
| └── *.md <- Preexisting docs
├── EasyLM/ <- Core utils and modeling files
| ├── models/ <- Packages and scripts specific to each model's architecture
| ├── scripts/ <- Benchmarking and evaluation scripts
| └── *.py <- Utilities and tools
├── examples/ <- Bash scripts for running EasyLM training
├── scripts/ <- Misc. extra scripts for benchmarking and evaluation.
└── LICENSE
The original authors are running an unofficial Discord community (unaffiliated with Google) for discussion related to training LLMs in JAX. Follow this link to join the Discord server. They have dedicated channels for several JAX-based LLM frameworks, including EasyLM, JaxSeq, Alpa, and Levanter.
OpenLLaMA is the original authors' permissively licensed reproduction of LLaMA, which can be used for commercial purposes. Check out the project main page here. OpenLLaMA can serve as a drop-in replacement for the LLaMA weights in EasyLM. Please refer to the LLaMA documentation for more details.
Koala is the original authors' chatbot fine-tuned on top of LLaMA. If you are interested in Koala, you can check out the blog post and the documentation for running it locally.
Tulu 2 is a suite of DPO aligned models built on top of the Llama 2 suite.
The installation method differs between GPU hosts and Cloud TPU hosts. The first step is to clone the repository from GitHub and add it to your PYTHONPATH:
git clone https://github.com/hamishivi/EasyLM.git
cd EasyLM
export PYTHONPATH="${PWD}:$PYTHONPATH"
The GPU environment can be installed via Anaconda.
conda env create -f scripts/gpu_environment.yml
conda activate EasyLM
The TPU host VM comes with Python and pip pre-installed. Simply run the following script to set up the TPU host.
./scripts/tpu_vm_setup.sh
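After either setup path, a quick way to confirm that JAX can see the accelerators is to query it from Python inside the activated environment (or on the TPU VM). This is a minimal sketch using standard JAX calls, not an EasyLM script; the devices it reports depend entirely on your hardware.

```python
import jax

# Backend JAX selected, typically "gpu", "tpu", or "cpu" ("cpu" means no accelerator was found).
print("backend:", jax.default_backend())
# On multi-host TPU Pods, device_count() is global across hosts; local_device_count() is per host.
print("devices:", jax.device_count(), "local:", jax.local_device_count())
print(jax.devices())
```

If a GPU host reports the `cpu` backend, the installed jaxlib most likely lacks GPU support, and the environment should be checked before launching training.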
The EasyLM documentation can be found in the `docs` directory.
If you find EasyLM useful in your research or applications, please cite it using the following BibTeX:
@software{geng2023easylm,
author = {Geng, Xinyang},
title = {EasyLM: A Simple And Scalable Training Framework for Large Language Models},
month = mar,
year = 2023,
url = {https://github.com/young-geng/EasyLM}
}
And the citation for this fork specifically, if you wish:
@software{hamishivi2023easylmfork,
author = {Ivison, Hamish and Wang, Yizhong and Pyatkin, Valentina and Liu, Jiacheng and Lu, Jiasen and Wu, Zeqiu},
title = {EasyLM-Fork: A Simple And Scalable Training Framework for Large Language Models},
month = oct,
year = 2023,
url = {https://github.com/hamishivi/EasyLM}
}
- The LLaMA implementation is from JAX_llama
- The JAX/Flax GPT-J and RoBERTa implementations are from transformers
- Most of the JAX utilities are from mlxu
- The codebase is heavily inspired by JAXSeq