diff --git a/README.md b/README.md index 72b8032f6..d84f5a660 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,50 @@ torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \ Note: passing CLI overrides like `--reset_trainer_state` is only necessary if you didn't update those fields in your config. +## Inference + +You can utilize our HuggingFace integration to run inference on the olmo checkpoints: + +```python +from hf_olmo import * # registers the Auto* classes + +from transformers import AutoModelForCausalLM, AutoTokenizer + +olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B") +tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B") + +message = ["Language modeling is "] +inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False) +response = olmo.generate(**inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95) +print(tokenizer.batch_decode(response, skip_special_tokens=True)[0]) +``` + +Alternatively, with the huggingface pipeline abstraction: + +```python +from transformers import pipeline +olmo_pipe = pipeline("text-generation", model="allenai/OLMo-7B") +print(olmo_pipe("Language modeling is")) +``` + + +### Inference on finetuned checkpoints + +If you finetune the model using the code above, you can use the conversion script to convert a native OLMo checkpoint to a HuggingFace-compatible checkpoint + +```bash +python hf_olmo/convert_olmo_to_hf.py --checkpoint-dir /path/to/checkpoint +``` + +### Quantization + +```python +olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B", torch_dtype=torch.float16, load_in_8bit=True) # requires bitsandbytes +``` + +The quantized model is more sensitive to typing / cuda, so it is recommended to pass the inputs as inputs.input_ids.to('cuda') to avoid potential issues. + + ## Evaluation -Additional tools for evaluating OLMo models are available at the [OLMo Eval](https://github.com/allenai/ai2-olmo-eval) repo. \ No newline at end of file +Additional tools for evaluating OLMo models are available at the [OLMo Eval](https://github.com/allenai/ai2-olmo-eval) repo.