Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions docs/source/en/chat_templating.md
Original file line number Diff line number Diff line change
Expand Up @@ -616,22 +616,65 @@ than the JSON schemas used for tools, no helper functions are necessary.
Here's an example of a RAG template in action:

```python
document1 = {
"title": "The Moon: Our Age-Old Foe",
"contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
}
from transformers import AutoTokenizer, AutoModelForCausalLM

document2 = {
"title": "The Sun: Our Age-Old Friend",
"contents": "Although often underappreciated, the sun provides several notable benefits..."
}
# Load the model and tokenizer
model_id = "CohereForAI/c4ai-command-r-v01-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
device = model.device # Get the device the model is loaded on

model_input = tokenizer.apply_chat_template(
messages,
documents=[document1, document2]
)
# Define conversation input
conversation = [
{"role": "user", "content": "What has Man always dreamed of?"}
]

# Define documents for retrieval-based generation
documents = [
{
"title": "The Moon: Our Age-Old Foe",
"text": "Man has always dreamed of destroying the moon. In this essay, I shall..."
},
{
"title": "The Sun: Our Age-Old Friend",
"text": "Although often underappreciated, the sun provides several notable benefits..."
}
]

# Tokenize conversation and documents using a RAG template, returning PyTorch tensors.
input_ids = tokenizer.apply_chat_template(
conversation=conversation,
documents=documents,
chat_template="rag",
tokenize=True,
add_generation_prompt=True,
return_tensors="pt").to(device)

# Generate a response
gen_tokens = model.generate(
input_ids,
max_new_tokens=100,
do_sample=True,
temperature=0.3,
)

# Decode and print the generated text along with generation prompt
gen_text = tokenizer.decode(gen_tokens[0])
print(gen_text)
```

<Tip>

The `documents` input for retrieval-augmented generation is not widely supported, and many models have chat templates which simply ignore this input.

To verify if a model supports the `documents` input, you can read its model card, or `print(tokenizer.chat_template)` to see if the `documents` key is used anywhere.

One model class that does support it, though, is Cohere's [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-08-2024) and [Command-R+](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024), through their `rag` chat template. You can see additional examples of grounded generation using this feature in their model cards.

</Tip>



## Advanced: How do chat templates work?

The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
Expand Down