diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5db0f53fcaa..16036f7d83f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -53,8 +53,6 @@ title: Community Tutorials - local: lora_without_regret title: LoRA Without Regret - - local: sentiment_tuning - title: Sentiment Tuning - local: multi_adapter_rl title: Multi Adapter RLHF title: Examples diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 488a99d2924..898459ba783 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -36,8 +36,6 @@ These notebooks are easier to run and are designed for quick experimentation wit Legacy / Older Notebooks - [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. -- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. -- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. ## Scripts diff --git a/docs/source/sentiment_tuning.md b/docs/source/sentiment_tuning.md deleted file mode 100644 index 025ffba0da5..00000000000 --- a/docs/source/sentiment_tuning.md +++ /dev/null @@ -1,31 +0,0 @@ -# Sentiment Tuning Examples - -The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`). - -Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): - -| File | Description | -| --- |--- | -| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | -| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | -| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. | - -## Usage - -```bash -# 1. run directly -python examples/scripts/ppo.py -# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed) -accelerate config # will prompt you to define the training configuration -accelerate launch examples/scripts/ppo.py # launches training -# 3. get help text and documentation -python examples/scripts/ppo.py --help -# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16 -python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16 -``` - -Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). - -## Few notes on multi-GPU - -To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`. diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md index 0c55f02e5ae..9aee78682c0 100644 --- a/examples/notebooks/README.md +++ b/examples/notebooks/README.md @@ -13,5 +13,3 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t Legacy / Older Notebooks - [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. -- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. -- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. diff --git a/examples/notebooks/gpt2-sentiment-control.ipynb b/examples/notebooks/gpt2-sentiment-control.ipynb deleted file mode 100644 index fa19ed3f1c1..00000000000 --- a/examples/notebooks/gpt2-sentiment-control.ipynb +++ /dev/null @@ -1,897 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tune GPT2 to generate controlled sentiment reviews\n", - "> Optimise GPT2 to produce IMDB movie reviews with controlled sentiment using a BERT sentiment classifier for rewards.\n", - "\n", - "**WARNING:** We often experienced loss spikes in this examples which caused model training to fail or slow down. There is a [GitHub issue](https://github.com/lvwerra/trl/issues/101) to track the issue." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - "\n", - "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", - "
\n", - "\n", - "\n", - "The experiment setup is very similar to the positive sentiment notebook. However, in this notebook we fine-tune GPT2 (small) to generate **controlled** movie reviews based on the IMDB dataset. The model gets the target sentiment and 5 tokens from a real review and is tasked to produce continuations with the targeted sentiment. The reward for the continuations is calculated with the logits of a BERT sentiment classifier. That reward is then used for PPO training." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup experiment" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import random\n", - "import torch\n", - "import wandb\n", - "import time\n", - "import os\n", - "from tqdm import tqdm\n", - "import numpy as np\n", - "import pandas as pd\n", - "from random import choices\n", - "import matplotlib.pyplot as plt\n", - "\n", - "tqdm.pandas()\n", - "\n", - "from datasets import load_dataset\n", - "\n", - "from transformers import AutoTokenizer, pipeline\n", - "\n", - "from trl import (\n", - " PPOTrainer,\n", - " PPOConfig,\n", - " AutoModelForCausalLMWithValueHead,\n", - " create_reference_model,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "sentiment_pipe_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\"}\n", - "\n", - "config = PPOConfig(\n", - " model_name=\"lvwerra/gpt2-imdb\",\n", - " steps=51200,\n", - " learning_rate=1.41e-5,\n", - " remove_unused_columns=False,\n", - " log_with=\"wandb\",\n", - ")\n", - "\n", - "txt_in_len = 5\n", - "txt_out_len = 20\n", - "seed = 1" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "np.random.seed(seed)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", - "https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load data and models" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load pre-trained GPT2 language models" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", - "gpt2_ref_model = create_reference_model(gpt2_model)\n", - "gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", - "\n", - "gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load IMDB dataset\n", - "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 500 characters long and take the first 1000 characters of each comment. The first filter we apply to avoid comments that are less than `txt_in_len` token long and the second to avoid tokenizing way more text than we actually need." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Found cached dataset imdb (/home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n", - "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-d314b4c14499bf03.arrow\n", - "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-0d5fcb05c95b1186.arrow\n" - ] - }, - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['review', 'sentiment'],\n", - " num_rows: 22578\n", - "})" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# create the dataset\n", - "#\n", - "dataset = load_dataset(\"stanfordnlp/imdb\", split=\"train\")\n", - "dataset = dataset.rename_columns({\"text\": \"review\", \"label\": \"sentiment\"})\n", - "# make sure the comments are are at least 500 and trim to 1000\n", - "dataset = dataset.filter(lambda x: len(x[\"review\"]) > 500, batched=False)\n", - "dataset = dataset.map(lambda x: {\"review\": x[\"review\"][:1000]}, batched=False)\n", - "\n", - "dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Tokenize IMDB reviews" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We tokenize all IMDB in advance to avoid tokenizing twice. In the first step we encode the queries and slice the first `txt_in_len` tokens. In a second step we decode these tokens back to text for later display." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-383f6ebf0ae41ee4.arrow\n", - "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-f4875ad4fccbbc1f.arrow\n" - ] - } - ], - "source": [ - "dataset = dataset.map(\n", - " lambda x: {\n", - " \"input_ids\": gpt2_tokenizer.encode(\" \" + x[\"review\"], return_tensors=\"pt\")[\n", - " 0, :txt_in_len\n", - " ]\n", - " },\n", - " batched=False,\n", - ")\n", - "dataset = dataset.map(\n", - " lambda x: {\"query\": gpt2_tokenizer.decode(x[\"input_ids\"])}, batched=False\n", - ")\n", - "dataset = dataset[:20480]\n", - "\n", - "from datasets import Dataset\n", - "\n", - "dataset = Dataset.from_dict(dataset)\n", - "dataset.set_format(\"pytorch\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 770, 2646, 373, 2192, 7867])" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset[3][\"input_ids\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "def collator(data):\n", - " return dict((key, [d[key] for d in data]) for key in data[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlvwerra\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.13.9" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in /home/leandro_huggingface_co/trl/examples/sentiment/notebooks/wandb/run-20230206_125743-jpcnr7jx" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run comic-music-184 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View project at https://wandb.ai/lvwerra/trl" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run at https://wandb.ai/lvwerra/trl/runs/jpcnr7jx" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "ppo_trainer = PPOTrainer(\n", - " config, gpt2_model, gpt2_ref_model, gpt2_tokenizer, dataset, data_collator=collator\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load BERT classifier\n", - "We load a BERT classifier fine-tuned on the IMDB dataset." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "if ppo_trainer.accelerator.num_processes == 1:\n", - " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", - "else:\n", - " device = ppo_trainer.accelerator.device\n", - "sentiment_pipe = pipeline(\n", - " \"sentiment-analysis\", \"lvwerra/distilbert-imdb\", device=device\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'label': 'NEGATIVE', 'score': 2.3350484371185303},\n", - " {'label': 'POSITIVE', 'score': -2.726576328277588}]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text = \"this movie was really bad!!\"\n", - "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", - "output" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'label': 'POSITIVE', 'score': 2.557040214538574},\n", - " {'label': 'NEGATIVE', 'score': -2.294790267944336}]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text = \"this movie was really good!!\"\n", - "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", - "output" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'label': 'POSITIVE', 'score': 0.8562759160995483},\n", - " {'label': 'NEGATIVE', 'score': -0.7086048126220703}]" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text = \"this movie was a documentary\"\n", - "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", - "output" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The resulting reward signal:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "def extract_pipe_output(outputs):\n", - " positive_logits = []\n", - " for out in outputs:\n", - " for element in out:\n", - " if element[\"label\"] == \"POSITIVE\":\n", - " positive_logits.append(torch.tensor(element[\"score\"]))\n", - " return positive_logits" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "-0.7086048126220703" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "output[1][\"score\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Control token dict\n", - "We will append the control token at the beginning of each query to signal the model what the target sentiment is. Each control sequence consists of three tokens:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "ctrl_str = [\"[negative]\", \"[neutral]\", \"[positive]\"]\n", - "device = torch.device(\n", - " \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - ") # this should be handled by accelerate\n", - "ctrl_tokens = dict(\n", - " (s, gpt2_tokenizer.encode(s, return_tensors=\"pt\").squeeze().to(device))\n", - " for s in ctrl_str\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'[negative]': tensor([ 58, 31591, 60], device='cuda:0'),\n", - " '[neutral]': tensor([ 58, 29797, 60], device='cuda:0'),\n", - " '[positive]': tensor([ 58, 24561, 60], device='cuda:0')}" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ctrl_tokens" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reward function" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "def pos_logit_to_reward(logit, task):\n", - " \"\"\"\n", - " Take the positive sentiment logit and scale it for the task.\n", - " task [negative]: reward = -logit\n", - " task [neutral]: reward = -2*abs(logit)+4\n", - " task [positive]: reward = logit\n", - " \"\"\"\n", - " for i in range(len(logit)):\n", - " if task[i] == \"[negative]\":\n", - " logit[i] = -logit[i]\n", - " elif task[i] == \"[neutral]\":\n", - " logit[i] = -2 * torch.abs(logit[i]) + 4\n", - " elif task[i] == \"[positive]\":\n", - " pass\n", - " else:\n", - " raise ValueError(\"task has to be in [0, 1, 2]!\")\n", - " return logit" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The following examples show the rewards for the cases where the classifier logit is 4, -4 and 0 for the three targets `['negative]`, `['neutral]` and `['positive']`. The scaling is not perfect as it differs between neutral and the other two classes. This is something to further investigate in the future. Ideally, one would use the logit output for each class individually, but since there is no dedicated class for neutral this is a workaround." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['[negative]', '[neutral]', '[positive]']\n" - ] - } - ], - "source": [ - "print(ctrl_str)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([-4., -4., 4.])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pos_logit_to_reward(torch.Tensor([4, 4, 4]), ctrl_str)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 4., -4., -4.])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pos_logit_to_reward(torch.Tensor([-4, -4, -4]), ctrl_str)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([-0., 4., 0.])" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pos_logit_to_reward(torch.Tensor([0, 0, 0]), ctrl_str)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Generation settings" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "generation_kwargs = {\n", - " \"min_length\": -1,\n", - " \"top_k\": 0.0,\n", - " \"top_p\": 1.0,\n", - " \"do_sample\": True,\n", - " \"pad_token_id\": gpt2_tokenizer.eos_token_id,\n", - " \"max_new_tokens\": txt_out_len,\n", - " \"eos_token_id\": -1,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optimize model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Steps**\n", - "\n", - "The training loop consists of the following steps:\n", - "1. Get a batch of queries and create random controls\n", - "2. Get the query responses from the policy\n", - "3. Join query and responses and tokenize for BERT analysis\n", - "4. Get sentiments for query/responses from BERT\n", - "5. Optimize policy with PPO using the (query, response, reward) triplet\n", - "6. Log all the training statistics\n", - "\n", - "**Training time**\n", - "\n", - "This step takes **~2h** on a P6000 GPU with the above specified settings." - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 8%|▊ | 6/80 [12:44<2:37:54, 128.03s/it]/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1045: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", - " warnings.warn(\n", - "100%|██████████| 80/80 [2:46:39<00:00, 124.99s/it] \n", - " 91%|█████████▏| 73/80 [2:30:39<14:35, 125.03s/it] " - ] - } - ], - "source": [ - "for epoch in range(2):\n", - " for batch in tqdm(ppo_trainer.dataloader):\n", - " (\n", - " logs,\n", - " game_data,\n", - " ) = (\n", - " dict(),\n", - " dict(),\n", - " )\n", - "\n", - " #### prepend a random control token\n", - " task_list = choices(ctrl_str, k=config.batch_size)\n", - " game_data[\"query\"] = [t + q for t, q in zip(task_list, batch[\"query\"])]\n", - " query_tensors = [\n", - " torch.cat((ctrl_tokens[t], input_ids))\n", - " for t, input_ids in zip(task_list, batch[\"input_ids\"])\n", - " ]\n", - "\n", - " #### get response from gpt2\n", - " response_tensors = []\n", - " for query in query_tensors:\n", - " response = ppo_trainer.generate(query, **generation_kwargs)\n", - " response_tensors.append(response.squeeze()[-txt_out_len:])\n", - " game_data[\"response\"] = [\n", - " gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors\n", - " ]\n", - "\n", - " #### sentiment analysis\n", - " texts = [q + r for q, r in zip(batch[\"query\"], game_data[\"response\"])]\n", - " logits = extract_pipe_output(sentiment_pipe(texts, **sentiment_pipe_kwargs))\n", - " rewards = pos_logit_to_reward(logits, task_list)\n", - "\n", - " #### Run PPO training\n", - " t = time.time()\n", - " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", - "\n", - " for cs in ctrl_str:\n", - " key = \"env/reward_\" + cs.strip(\"[]\")\n", - " stats[key] = np.mean(\n", - " [r.cpu().numpy() for r, t in zip(rewards, task_list) if t == cs]\n", - " )\n", - " ppo_trainer.log_stats(stats, game_data, rewards)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training progress\n", - "If you are tracking the training progress with Weights&Biases you should see a plot similar to the following:\n", - "\n", - "
\n", - "\n", - "

Figure: Reward mean and distribution evolution during training.

\n", - "
\n", - "\n", - "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", - "\n", - "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model inspection" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reward distribution\n", - "First, we can have a look at the reward distribution. Both the negative and positive rewards are clearly shifted to high rewards. The neutral rewards, however, are still centered around zero. There are a few possible explanations for this. There could be a bug in the code and the way the neutral rewards are calculated. Another problem could be that sentence sometimes start with a strong sentiment and it is hard for the model shift the sentiment towards neutral." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGzCAYAAAAMr0ziAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPCUlEQVR4nO3deVwVZf8//tecw4HDroiyibKImSmQkKi5B6K3mt4tLvm4RSq7S7lvjTtNLAVcPqip0aLZnbdL3ZK0qP2+5o0SSVmiFor7lklubGqIgB4OnPn9YWfyyGE5h+UM8Ho+Hjw8c80117znOoPzZuaaGUEURRFEREREMqawdABEREREdWHCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkQNtmnTJgiCgNzcXLOWnzZtGnx8fAzKBEFAQkJCg2OrS2ZmJgRBQGZmplQ2dOhQ9OrVq8nXDQC5ubkQBAGbNm1qlvURtVRMWIio1UhJSUFycrKlwzBKzrERtQRWlg6AiMiYO3fuwMrKtP+iUlJScOLECcyePbveywwePBh37tyBtbW1iRGapqbYunbtijt37kClUjXp+olaOp5hIZKBsrIyS4dQK51Oh7t37zbrOtVqtckJiynu3r0LnU4HhUIBtVoNhcIy/x0KggC1Wg2lUmmR9RO1FExYiJpZQkICBEHAqVOn8Nxzz6F9+/YYOHCgNP+///0vQkJCYGtrCxcXF0yaNAmXL1+W5r/77rtQKpUoLi6WylatWgVBEBAbGyuVVVVVwdHREa+//rpUtnLlSgwYMAAdOnSAra0tQkJC8MUXX1SLURAExMTEYMuWLXjkkUdgY2ODtLQ0AMDJkycxfPhw2NraonPnzliyZAl0Ol29t3/Hjh3o1asX1Go1evXqhe3btxut9+AYltu3b2P27Nnw8fGBjY0NOnXqhIiICBw+fBjAvXEnX3/9NX777TcIggBBEKRxMfpxKlu3bsWbb74JLy8v2NnZoaSkxOgYFr3s7GwMGDAAtra28PX1xbp16wzm1zR258E2a4utpjEs3377LQYNGgR7e3u0a9cO48aNw+nTpw3q6PelX375BdOmTUO7du3g7OyM6OholJeX1/wlELVAvCREZCHPPvssAgIC8H//938QRREAsHTpUixYsAATJkzAiy++iKKiIrz33nsYPHgwjhw5gnbt2mHQoEHQ6XT44YcfMGbMGADAvn37oFAosG/fPqn9I0eOoLS0FIMHD5bK3nnnHTz55JOYMmUKKioqsHXrVjz77LPYuXMnRo8ebRDft99+i88++wwxMTFwdXWFj48P8vPzMWzYMFRWVmLevHmwt7fHv//9b9ja2tZrm/fs2YOnn34aPXv2RFJSEm7cuIHo6Gh07ty5zmVffvllfPHFF4iJiUHPnj1x48YN/PDDDzh9+jT69OmDN954A7du3cKVK1fw9ttvAwAcHBwM2li8eDGsra3x2muvQaPR1HoZ6Pfff8df/vIXTJgwAZMnT8Znn32GV155BdbW1nj++efrtb169Yntft988w1GjRoFPz8/JCQk4M6dO3jvvffw+OOP4/Dhw9UGKE+YMAG+vr5ISkrC4cOHsX79enTq1AnLly83KU4iWROJqFnFx8eLAMTJkycblOfm5opKpVJcunSpQfnx48dFKysrqbyqqkp0cnIS586dK4qiKOp0OrFDhw7is88+KyqVSvH27duiKIri6tWrRYVCIf7+++9SW+Xl5QZtV1RUiL169RKHDx9uUA5AVCgU4smTJw3KZ8+eLQIQDx48KJUVFhaKzs7OIgDx4sWLtW57cHCw6OHhIRYXF0tle/bsEQGIXbt2rRZDfHy8NO3s7CzOnDmz1vZHjx5drR1RFMW9e/eKAEQ/P79qfaCft3fvXqlsyJAhIgBx1apVUplGoxGDg4PFTp06iRUVFaIoiuLGjRuNbrexNmuK7eLFiyIAcePGjVKZfj03btyQyo4ePSoqFApx6tSpUpl+X3r++ecN2vzrX/8qdujQodq6iFoyXhIispCXX37ZYHrbtm3Q6XSYMGECrl+/Lv24u7sjICAAe/fuBQAoFAoMGDAA33//PQDg9OnTuHHjBubNmwdRFJGVlQXg3lmXXr16oV27dtI67j8T8vvvv+PWrVsYNGiQdFnlfkOGDEHPnj0Nynbt2oV+/fqhb9++UlnHjh0xZcqUOrc3Ly8POTk5iIqKgrOzs1QeERFRbT3GtGvXDgcPHsS1a9fqrFuTqKioep8NsrKywt///ndp2traGn//+99RWFiI7Oxss2Ooi76fpk2bBhcXF6k8MDAQERER2LVrV7VlHtyXBg0ahBs3bqCkpKTJ4iRqbkxYiCzE19fXYPr8+fMQRREBAQHo2LGjwc/p06dRWFgo1R00aBCys7Nx584d7Nu3Dx4eHujTpw+CgoKky0I//PADBg0aZLCOnTt3ol+/flCr1XBxcUHHjh3xwQcf4NatW3XGBwC//fYbAgICqpU/9NBDdW7vb7/9BgBmL79ixQqcOHEC3t7e6Nu3LxISEvDrr7/Wudz9jG1TTTw9PWFvb29Q1r17dwAw+3kz9aHvJ2N98vDDD+P69evVBml36dLFYLp9+/YA7iWlRK0Fx7AQWciDf+nrdDoIgoD//e9/Ru8YuX/Mw8CBA6HVapGVlYV9+/ZJicmgQYOwb98+nDlzBkVFRQYJy759+/Dkk09i8ODBWLt2LTw8PKBSqbBx40akpKTUGZ+lTZgwAYMGDcL27duxZ88evPXWW1i+fDm2bduGUaNG1auNxt4mQRCMlldVVTXqeupS0x1G4h9jo4haAyYsRDLh7+8PURTh6+sr/SVfk759+8La2hr79u3Dvn37MGfOHAD3niny0UcfISMjQ5rW+/LLL6FWq7F7927Y2NhI5Rs3bqx3jF27dsX58+erlZ89e7ZeywIwe3kA8PDwwIwZMzBjxgwUFhaiT58+WLp0qZSw1JRAmOPatWsoKyszOMty7tw5AJAGverPZNx/xxbw51mS+9U3Nn0/GeuTM2fOwNXVtdqZH6K2gJeEiGTiqaeeglKpRGJiYrW/jEVRxI0bN6RptVqNxx57DJ9++ikuXbpkcIblzp07ePfdd+Hv7w8PDw9pGaVSCUEQDP76z83NxY4dO+od41/+8hccOHAAhw4dksqKioqwZcuWOpf18PBAcHAwNm/ebHAJKj09HadOnap12aqqqmqXrTp16gRPT09oNBqpzN7e3ujlLXNUVlbiww8/lKYrKirw4YcfomPHjggJCQFwL8kEII0n0sf673//u1p79Y3t/n66PxE6ceIE9uzZg7/85S/mbhJRi8YzLEQy4e/vjyVLliAuLg65ubkYP348HB0dcfHiRWzfvh0vvfQSXnvtNan+oEGDsGzZMjg7O6N3794A7h3EH3roIZw9exbTpk0zaH/06NFYvXo1Ro4cieeeew6FhYVYs2YNunXrhmPHjtUrxrlz5+KTTz7ByJEjMWvWLOm25q5du9arjaSkJIwePRoDBw7E888/j5s3b+K9997DI488gtLS0hqXu337Njp37oxnnnkGQUFBcHBwwDfffIOffvoJq1atkuqFhIQgNTUVsbGxeOyxx+Dg4ICxY8fWa9se5OnpieXLlyM3Nxfdu3dHamoqcnJy8O9//1t6Ku0jjzyCfv36IS4uDjdv3oSLiwu2bt2KysrKau2ZEttbb72FUaNGoX///njhhRek25qdnZ2b5f1KRLJkyVuUiNoi/a2oRUVFRud/+eWX4sCBA0V7e3vR3t5e7NGjhzhz5kzx7NmzBvW+/vprEYA4atQog/IXX3xRBCD+5z//qdb2f/7zHzEgIEC0sbERe/ToIW7cuFGK534AaryF+NixY+KQIUNEtVotenl5iYsXLxb/85//1Ou2Zv32Pfzww6KNjY3Ys2dPcdu2bWJUVFSttzVrNBpxzpw5YlBQkOjo6Cja29uLQUFB4tq1aw2WKS0tFZ977jmxXbt2BrdK628z/vzzz6vFU9NtzY888oj4888/i/379xfVarXYtWtX8f3336+2/IULF8Tw8HDRxsZGdHNzE+fPny+mp6dXa7Om2Izd1iyKovjNN9+Ijz/+uGhrays6OTmJY8eOFU+dOmVQp6Z9qabbrYlaMkEUOSqLiIiI5I1jWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcleq3hwnE6nw7Vr1+Do6Nioj+YmIiKipiOKIm7fvg1PT08oFLWfQ2kVCcu1a9fg7e1t6TCIiIjIDJcvX0bnzp1rrdMqEhZHR0cA9zbYycnJ7Ha0Wi327NmDESNGSI/ebovYD+wDgH0AsA/02A/sA6Bp+qCkpATe3t7Scbw2rSJh0V8GcnJyanDCYmdnBycnpza7QwLsB4B9ALAPAPaBHvuBfQA0bR/UZzgHB90SERGR7DFhISIiItljwkJERESy1yrGsNSHKIqorKxEVVVVjXW0Wi2srKxw9+7dWuu1di2xH1QqFZRKpaXDICKiJtImEpaKigrk5eWhvLy81nqiKMLd3R2XL19u089zaYn9IAgCOnfuDAcHB0uHQkRETaDVJyw6nQ4XL16EUqmEp6cnrK2tazwI63Q6lJaWwsHBoc4H2LRmLa0fRFFEUVERrly5goCAAJ5pISJqhVp9wlJRUQGdTgdvb2/Y2dnVWlen06GiogJqtbpFHKibSkvsh44dOyI3NxdarZYJCxFRK9QyjkaNoKUceMk8LeXSFRERmYdHcSIiIpI9JixEREQke61+DEtt3k4/ZzAtiiI0Gg1sbGya5BLDqxHdTao/dOhQfPfddwCAI0eOIDg4uNFjamyCIGD79u0YP358o7SXmZmJYcOGAQDGjRuHHTt2NEq7RETUsvAMi8xNnz4deXl56NWrl6VDMZCQkGA0gcrLy8OoUaMabT0DBgxAXl4eJkyY0GhtEhFRy9Omz7C0BHZ2dnB3d7d0GPXW2LFaW1vD3d0dtra20Gg0jdo2ERG1HDzD0oJkZmZCEARkZGQgNDQUdnZ2GDBgAM6ePWtQ76uvvkKfPn2gVqvh5+eHxMREVFZWSvPPnDmDgQMHQq1Wo2fPnvjmm28gCILB5Zb4+Hj06NEDdnZ28PPzw4IFC6DVagEAmzZtQmJiIo4ePQpBECAIAjZt2gQABu0MGDAAr7/+ukFsRUVFUKlU+P777wEAGo0Gr732Gry8vGBvb4+wsDBkZmY2bscREVGLxzMsLdAbb7yBVatWoWPHjnj55Zfx/PPP48cffwQA7Nu3D1OnTsW7776LQYMG4cKFC3jppZcA3EtCqqqqMH78eHTp0gUHDx7E7du38a9//avaOhwdHbFhwwZ07twZx48fx/Tp0+Ho6Ii5c+di4sSJOHHiBNLS0vDNN98AAJydnau1MWXKFKxYsQLLli2TxgSlpqbC09MTgwYNAgDExMTg1KlT2Lp1Kzw9PbF9+3aMHDkSx48fR0BAQJP0H9Vsbc5a6bOgE+AJT6w/vh6iQrRgVMCM4BkWXT8RWR7PsLRAS5cuxZAhQ9CzZ0/MmzcP+/fvx927dwEAiYmJmDdvHqKiouDn54eIiAgsXrwYH374IQAgPT0dFy5cwMcff4ygoCAMHDgQS5curbaO1157DQMGDICPjw/Gjh2L1157DZ999hkAwNbWFg4ODrCysoK7u7t0yeZBEyZMwLVr1/DDDz9IZSkpKZg8eTIEQcClS5ewceNGfP755xg0aBD8/f3x2muvYeDAgdi4cWNTdB0REbVQPMPSAgUGBkqfPTw8AACFhYXo0qULjh49ih9//NEgCamqqsLdu3dRXl6Os2fPwtvb22CsSd++fautY9u2bfjPf/6DCxcuoLS0FJWVlXBycjIpzo4dO2LEiBHYsmULBg0ahIsXLyIrK0tKno4fP46qqip0725495RGo0GHDh1MWhcREbVuTFhaIJVKJX3WX2rR6XQAgNLSUiQmJuKpp56qtpxara5X+1lZWXjppZeQkJCAkSNHwtnZGVu3bsWqVatMjnXKlCn45z//iffeew8pKSno3bs3evfuLcWqVCqRnZ1d7XH6fIkhERHdz6xLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoM60adOkwZv6n5EjR5oTWpvXp08fnD17Ft26dav2o1Ao8NBDD+Hy5csoKCiQlvnpp58M2sjKyoK3tzfmz5+P0NBQBAQE4LfffjOoY21tjaqqqjrjGTduHO7evYu0tDSkpKRgypQp0rxHH30UVVVVKCwsrBZrS7ozioiImp7JZ1hSU1MRGxuLdevWISwsDMnJyYiMjMTZs2fRqVOnavVdXFzwxhtvoEePHrC2tsbOnTsRHR2NTp06ITIyUqo3cuRIg3ELNjY2Zm5S27Zw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCSIiIuDv74+oqCisWLECt2/fxptvvgngz7M13bp1w5UrV7B161aEhYXh66+/xvbt2w3W4+Pjg4sXLyInJwedO3eGo6Oj0e/M3t4e48ePx4IFC3D69GlMnjxZmte9e3dMmTIFU6dOxapVq/Doo4+iqKgIGRkZCAwMxOjRo5uwp4iIqCUxOWFZvXo1pk+fjujoaADAunXr8PXXX2PDhg2YN29etfpDhw41mJ41axY2b96MH374wSBhsbGxafa/qh988qxOp0NJSQmcnJxa7MsSIyMjsXPnTixatAjLly+HSqVCjx498OKLLwIAlEolduzYgRdffBGPPfYY/Pz88NZbb2Hs2LHSJaMnn3wSr7zyCv75z39Co9Fg9OjRWLBgARISEqT1PP3009i2bRuGDRuG4uJibNy4EdOmTTMa05QpU/CXv/wFgwcPRpcuXQzmbdy4EUuWLMG//vUvXL16Fa6urujXrx/GjBnTJP1DREQtk0kJS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUG8zIzM9GpUye0b98ew4cPx5IlS2oceKnRaAweIlZSUgIA0Gq10rNC9LRaLURRhE6nk8Z51Baf/t+66jaX+2MZPHiwdBlGXxYYGFitLCIiAhEREdXa0s/v3r279BwUANIt0X5+ftDpdBBFEYsWLcLbb79t8IqCf/7zn1IbKpVKumvo/vYfjAW4l0QZKwfuJVDx8fGIj4+vMV59P9T2vejj1mq11cbDmEO/Hz24P7V2gk6o9vn+Mkux1PfQVveDB7Ef2AdA0/SBKW0Jov4oXQ/Xrl2Dl5cX9u/fj/79+0vlc+fOxXfffYeDBw8aXe7WrVvw8vKCRqOBUqnE2rVr8fzzz0vzt27dCjs7O/j6+uLChQuYP38+HBwckJWVZfTgk5CQgMTExGrlKSkpsLOzMyjT33rr7e0Na2vr+m6qLIwZMwaHDh2CtbU1du/ejUceeaRR2t25cyfs7e3h7++PX3/9FXFxcXB2dkZaWlqjtN+Y9u/fjwkTJkCj0Uh3HBlTUVGBy5cvIz8/3+AheUREJF/l5eV47rnncOvWrTrvRG2Wu4QcHR2Rk5OD0tJSZGRkIDY2Fn5+ftLlokmTJkl1e/fujcDAQPj7+yMzMxNPPPFEtfbi4uIQGxsrTZeUlMDb2xsjRoyotsF3797F5cuX4eDgUOddMqIo4vbt23B0dGySlx+a6tNPP8WdO3cAAF26dGm0hKuyshKvv/46Ll26BFdXVzzxxBNYuXKl1Hdy6ochQ4bg8OHDAO7dOVTTDn337l3Y2tpi8ODB9b4bqjZarRbp6emIiIgwuCurtVt/fL30WdAJ8LjqgTyvPIs/OO7F3i9aZL1tdT94EPuBfQA0TR/or5DUh0kJi6urK5RKpcEdJgBQUFBQ6/gThUKBbt26AQCCg4Nx+vRpJCUlVRvfoufn5wdXV1f88ssvRhMWGxsbowM8VSpVtU6sqqqCIAhQKBR1jkvRX27Q17c0b2/vJml32rRpNY43AeTVD/b29tWe02KMQqGAIAhG94GGaOz25M5YYiIqRIsnLJb+DtraflAT9gP7AGjcPjClHZOORtbW1ggJCUFGRoZUptPpkJGRYXCJqC46na7WF9lduXIFN27ckB6KRkRERG2byZeEYmNjERUVhdDQUPTt2xfJyckoKyuT7hqaOnUqvLy8kJSUBABISkpCaGgo/P39odFosGvXLnzyySf44IMPAPz5oLOnn34a7u7uuHDhAubOnYtu3boZ3EVEREREbZfJCcvEiRNRVFSEhQsXIj8/H8HBwUhLS4ObmxsA4NKlSwaXEcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJAO7dJXLs2DFs3rwZxcXF8PT0xIgRI7B48WI+i4WIiIgAmDnoNiYmBjExMUbnZWZmGkwvWbIES5YsqbEtW1tb7N6925wwiIiIqI2w/MhSIiIiojowYSEiIiLZa9tva96bZDApiCLUGg0EGxugKZ4/Miyu7jr3GTp0KL777jsAwJEjRxAcHNz4MTWDTZs2Yfbs2SguLpam9YO0Z82aheTkZMsFR0RELQLPsMjc9OnTkZeXh169ejXbOjMzM9G+fXspwWhsEydORF5enkm3whMRUdvWts+wtAB2dnbN/lLI+qqoqDDr6bu2trawtbVtca9KICIiy+EZlhYkMzMTgiAgIyMDoaGhsLOzw4ABA3D27FmDel999RX69OkDtVoNPz8/JCYmSu/Xyc3NhSAIyMnJkeoXFxdDEARkZmYiNzdXerpwhw4dIAiC9FTcoUOHIiYmBrNnz4arq6v0nJzVq1ejd+/esLe3h7e3N2bMmIHS0tKm7xAiImozmLC0QG+88QZWrVqFn3/+GVZWVgYvkty3bx+mTp2KWbNm4dSpU/jwww+xadMmLF26tF5te3t74/PPPwcAnD59Gnl5eXjnnXek+Zs3b4a1tTV+/PFHrFu3DsC9x+K/++67OHnyJDZv3oxvv/0Wc+fObcQtJiKito6XhFqgpUuXYsiQIQCAefPmYfTo0bh79y7UajUSExMxb948REVFAbj3XqbFixdj7ty5iI+Pr7NtpVIJFxcXAECnTp2kz3oBAQFYsWKFQdns2bOlzz4+PliyZAlefvllrF27tiGbSUREJGHC0gIFBgZKn/XvWyosLESXLl1w9OhR/PjjjwZnVKqqqnD37l2Ul5c3eN0hISHVyr755hskJSXhzJkzKCkpQWVlpbQ+Ozu7Bq+TiIiICUsLdP/bLYU/br/Wv2FZ/26mp556qtpyarVaem2CKP759l2tVlvvddvb2xtM5+bmYsyYMXjllVewdOlSuLi44IcffsALL7yAiooKJixERNQomLC0Mn369MHZs2fRrVs3o/M7duwIAMjLy8Ojjz4KAAYDcAFId+9UVVXVub7s7GzodDqsWrVKSoY+++wzc8MnIiIyiglLK7Nw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCWxtbdGvXz8sW7YMvr6+KCwsxJtvvmnQRteuXSEIAnbu3IkxY8bA1tYWDg4ORtfXrVs3aLVavPfeexg7dqzBYFwiIqLG0rYTlgeePCvqdLhbUgJrJycIipZ5A1VkZCR27tyJRYsWYfny5VCpVOjRowdefPFFqc6GDRvwwgsvICQkBA899BBWrFiBESNGSPO9vLwQFxeH+fPn44UXXsDUqVOxadMmo+sLCgrC6tWrsXz5csTFxWHw4MFISkrC1KlTm3pTiYioDWnbCUsLM3ToUIOxJwAQHBxcrSwyMlJ6RooxDz/8MPbv329Q9mAbc+bMweLFi6XLPED1N3Hrvfrqq3j11VcNyv72t79Jn6dNmyY9y4WIiMgcLfM0Qhuydu1aODg44Pjx45YOpdFs2bIFDg4O2Ldvn6VDISKiFoJnWGRsy5YtuHPnDgCgS5cuFo6m8Tz55JMICwsDALRr186ywRARUYvAhEXGvLy8LB1Ck3B0dISjo6OlwyAiohaEl4SIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPdwkREdXH3iRLR2DcA0/spnoy9fsUFQB6APtWA4KuSUICwO+zFm06YVmbs9ZgWhRFaDQa2NjYSG9BbkwzgmeYVH/o0KH47rvvAABHjhxBcHBwo8dkbJ1BQUFITEyssc6mTZswe/ZsFBcXN9p6p02bhs2bNwMAtm/fjvHjxzda20RE1PLxkpDMTZ8+HXl5eejVq1ezrG/btm1YtGiRNO3j44Pk5GSDOhMnTsS5c+cadb3vvPMO8vLyGrVNIiJqPdr0GZaWwM7ODu7u7s22PhcXF+h0OpSUlNRYx9bWFra2to26XmdnZzg7Ozdqm0RE1HrwDEsLkpmZCUEQ8PXXXyMwMBBqtRr9+vXDiRMnDOp9+eWXeOSRR2BjYwMfHx+sWrXKYP7atWsREBAAtVoNNzc3PPPMM9K8oUOHSi8yHD58OH777Te8+uqrEARBuky2adMm6ZH6586dgyAIOHPmjME63n77bfj7+0vTJ06cwKhRo+Dg4AA3Nzf87W9/w/Xr1xutb4iIqHVjwtICzZkzB6tWrcJPP/2Ejh07YuzYsdBqtQCA7OxsTJgwAZMmTcLx48eRkJCABQsWYNOmTQCAn3/+Gf/85z+xaNEinD17FmlpaRg8eLDR9XzxxRfo3LkzFi1ahLy8PKOXbLp3747Q0FBs2bLFoHzLli147rnnAADFxcUYPnw4Hn30Ufz8889IS0tDQUEBJkyY0Ii9QkRErRkvCbVA8fHxiIiIAABs3rwZnTt3xvbt2zFhwgSsXr0aTzzxBBYsWADgXkJx6tQpvPXWW5g2bRouXboEe3t7jBkzBo6OjujatSseffRRo+txcXGBUqmEo6NjrZelpkyZgvfffx+LFy8GcO+sS3Z2Nv773/8CAN5//308+uij+L//+z9pmQ0bNsDb2xvnzp1D9+7dG6VfiIio9eIZlhaof//+0mcXFxc89NBDOH36NADg9OnTePzxxw3qP/744zh//jyqqqoQERGBrl27ws/PD3/729+wZcsWlJeXNyieSZMmITc3FwcOHABw7+xKnz590KNHDwDA0aNHsXfvXjg4OEg/+nkXLlxo0LqJiKhtYMLSxjg6OuLw4cP49NNP4eHhgYULFyIoKKhBtyi7u7tj+PDhSElJAQCkpKRgypQp0vzS0lKMHTsWOTk5Bj/nz5+v8XIUERHR/ZiwtED6MxkA8Pvvv+PcuXN4+OGHAQAPP/wwfvzxR4P6P/74I7p37w6lUgkAsLKyQnh4OFasWIFjx44hNzcX3377rdF1WVtbo6qqqs6YpkyZgtTUVGRlZeHXX3/FpEmTpHl9+vTByZMn4ePjg27duhn82Nvbm7z9RETU9jBhaYEWLVqEjIwMnDhxAtOmTYOrq6v0oLV//etfyMjIwOLFi3Hu3Dls3rwZ77//Pl577TUAwM6dO/Huu+8iJycHv/32Gz7++GPodDo89NBDRtfl4+OD77//HlevXq31rp6nnnoKt2/fxiuvvIJhw4bB09NTmjdz5kzcvHkTkydPxk8//YQLFy5g9+7diI6OrlcyRERE1KYH3T745Fn980ecnJygUMg3l1u2bBlmzZqF8+fPIzg4GP/v//0/WFtbA7h3NuOzzz7DwoULsXjxYnh4eGDRokWYNm0aAKBdu3bYtm0bEhIScPfuXQQEBODTTz/FI488YnRdixYtwt///nf4+/tDo9FAFEWj9RwdHTF27Fh89tln2LBhg8E8T09P/Pjjj3j99dcxYsQIaDQadO3aFSNHjpR1PxMRkXy06YSlpRo4cGC1Z6/c7+mnn8bTTz9d47KZmZk1LpuZmWnw4Lh+/frh6NGjBnWmTZsmJUD3S01NRWpqqtF2AwICsG3bthrXS0REVBv+eStza9euhYODA44fP27pUJrUyy+/DAcHB0uHQUREMsUzLDK2ZcsW3LlzBwDQpUsX7N+/38IRNZ1FixZJ42w8PDwsHA0REckNExYZ8/LyMpgeOnRojWNIWrpOnTqhU6dOlg6DiIhkyqxLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoI4oili4cCE8PDxga2uL8PBwnD9/3pzQiIiIqBUyOWFJTU1FbGws4uPjcfjwYQQFBSEyMhKFhYVG67u4uOCNN95AVlYWjh07hujoaERHR2P37t1SnRUrVuDdd9/FunXrcPDgQdjb2yMyMhJ37941f8se0FrPTNA9/H6JiFo3kxOW1atXY/r06YiOjkbPnj2xbt062NnZVbuVVW/o0KH461//iocffhj+/v6YNWsWAgMD8cMPPwC4d6BJTk7Gm2++iXHjxiEwMBAff/wxrl27hh07djRo4wBApVIBQIMfP0/yVlFRAQDSw/GIiKh1MWkMS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUAgIsXLyI/Px/h4eFSPWdnZ4SFhSErK8vgial6Go0GGo1GmtbfgqvVaqW3Ft/P0dERBQUF0Ol0sLOzgyAINcZXUVGBO3fu1FinLWhp/aDT6VBYWAi1Wg1RFI3uA6bSt9EYbbUkgk6o9vn+Mkux1PdgsB+IMr2pshn6plX+Ppj4fWr/qK9t6v1Axn3cFPuBKW2ZlLBcv34dVVVVcHNzMyh3c3PDmTNnalzu1q1b8PLygkajgVKpxNq1a6W3Defn50ttPNimft6DkpKSkJiYWK18z549sLOzM7qMo6MjysrK+KCyVkqr1aKoqAjHjh1r1HbT09MbtT2584RntTKPq5a/a2vX5V0WXf+9/aCHRWOo0a7m65vW9ftg3veZXtrEb5dvxu/TXI25H5hy9aNZ7hJydHRETk4OSktLkZGRgdjYWPj5+WHo0KFmtRcXF4fY2FhpuqSkBN7e3hgxYgScnJxqXK6qqgqVlZU1jneorKzE/v37MWDAAFhZtd0bqFpaPwiCAJVK1ajJqFarRXp6OiIiIqTLim3B+uPrpc+CToDHVQ/keeVBVFh2jNCLvV+0yHoN9oMD71kkhjoNiq27TgO1yt+HfatNqq4VFUgv7Y4Ih3NQCbomCgrN8n2aqyn2A/0Vkvow6Wjk6uoKpVKJgoICg/KCggK4u7vXuJxCoUC3bt0AAMHBwTh9+jSSkpIwdOhQabmCggKD528UFBQgODjYaHs2NjawsbGpVq5SqWrtxLo6WKvVorKyEg4ODq3nl9IM7Ic/1bVPtTbGEhNRIVo8YbH0d6BSqZr2INUQzdg3rer3wczvUyXomnZfaAH925j7gSntmPQnqbW1NUJCQpCRkSGV6XQ6ZGRkoH///vVuR6fTSWNQfH194e7ubtBmSUkJDh48aFKbRERE1HqZfL4/NjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDfeJDQ0VHp53q5du/DJJ5/ggw8+AHDvdP7s2bOxZMkSBAQEwNfXFwsWLICnp6f0BmIiIiJq20xOWCZOnIiioiIsXLgQ+fn5CA4ORlpamjRo9tKlSwZjCcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJUp25c+eirKwML730EoqLizFw4ECkpaVBrVY3wiYSERFRS2fWiMqYmBjExMQYnffgm4CXLFmCJUuW1NqeIAhYtGgRFi1aZE44RERE1MrxHl8iIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2zHqXEFFr8Xb6OaPlglgFXwBr9v4CUVA2a0yvRnRv1vUREbUEPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0rSwdARIbeTj9nsXUfLrkhfbaCAuOtPC0WCxHR/XiGhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9sxKWNasWQMfHx+o1WqEhYXh0KFDNdb96KOPMGjQILRv3x7t27dHeHh4tfrTpk2DIAgGPyNHjjQnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqQb2RI0ciLy9P+vn000/N2yIiIiJqdUxOWFavXo3p06cjOjoaPXv2xLp162BnZ4cNGzYYrb9lyxbMmDEDwcHB6NGjB9avXw+dToeMjAyDejY2NnB3d5d+2rdvb94WERERUatj0oPjKioqkJ2djbi4OKlMoVAgPDwcWVlZ9WqjvLwcWq0WLi4uBuWZmZno1KkT2rdvj+HDh2PJkiXo0KGD0TY0Gg00Go00XVJSAgDQarXQarWmbJIB/bINaaM1aEv9IIhVtZbXNL+1srrvbxj9Z0EnWCociaX2RYPfBVGmQ/6aoW9a5f8JJn6f2j/qa5t6P5BxHzfFfmBKW4IoimJ9K1+7dg1eXl7Yv38/+vfvL5XPnTsX3333HQ4ePFhnGzNmzMDu3btx8uRJqNVqAMDWrVthZ2cHX19fXLhwAfPnz4eDgwOysrKgVCqrtZGQkIDExMRq5SkpKbCzs6vv5hAREZEFlZeX47nnnsOtW7fg5ORUa91mfTT/smXLsHXrVmRmZkrJCgBMmjRJ+ty7d28EBgbC398fmZmZeOKJJ6q1ExcXh9jYWGm6pKREGhtT1wbXRqvVIj09HREREVCpVGa309K1pX5Ys/cXo+WCWAWfuxeQq/aHKFRPmluro7e3SZ+toMAYq57I88qDqKj33zVN4sXeL1pkvQa/Cwfes0gMdRoUW3edBmqV/yfsW21Sda2oQHppd0Q4nINK0DVRUPKm7fePRt8P9FdI6sOkhMXV1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmX45ptvEBgYWGtdPz8/uLq64pdffjGasNjY2MDGxqZauUqlapRObKx2Wrq20A91JSOioGxTCUslqv9HLCpEiycslt4PVSqVfA9Szdg3rer/BDO/T5Wgk+++0NT++O4bcz8wpR2TLsZZW1sjJCTEYMCsfgDt/ZeIHrRixQosXrwYaWlpCA0NrXM9V65cwY0bN+Dh4WFKeERERNRKmTx6KDY2Fh999BE2b96M06dP45VXXkFZWRmio6MBAFOnTjUYlLt8+XIsWLAAGzZsgI+PD/Lz85Gfn4/S0lIAQGlpKebMmYMDBw4gNzcXGRkZGDduHLp164bIyMhG2kwiIiJqyUwewzJx4kQUFRVh4cKFyM/PR3BwMNLS0uDm5gYAuHTpEhSKP/OgDz74ABUVFXjmmWcM2omPj0dCQgKUSiWOHTuGzZs3o7i4GJ6enhgxYgQWL15s9LIPERERtT1mDbqNiYlBTEyM0XmZmZkG07m5ubW2ZWtri927d5sTBhEREbURMn2wABEREdGfmLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItmzsnQA1LqszVlr6RCqmRE8w9IhUCu0tviYpUO4pxl+5wSdAE94Yv3x9RAVosnL83eQGgPPsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI93tZMzSrrwo1mX6em6Fyzr7M1OXTxJiqhs2gM93+Hr0Z0t2AkRGQpPMNCREREsseEhYiIiGSPl4SIiKhpXNx379/fb1k2DmoVeIaFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZM+shGXNmjXw8fGBWq1GWFgYDh06VGPdjz76CIMGDUL79u3Rvn17hIeHV6sviiIWLlwIDw8P2NraIjw8HOfPnzcnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqVGfFihV49913sW7dOhw8eBD29vaIjIzE3bt3zd8yIiIiajVMfpfQ6tWrMX36dERHRwMA1q1bh6+//hobNmzAvHnzqtXfsmWLwfT69evx5ZdfIiMjA1OnToUoikhOTsabb76JcePGAQA+/vhjuLm5YceOHZg0aVK1NjUaDTQajTRdUlICANBqtdBqtaZukkS/bEPaaA0a0g+CTqh1vpUFrkIKYpXZy5izbEt2//ej/2yJ7+xB938Pzfn7afC7IBr2gyCXV7HV8TvXGPS/13X9fld3r4+0ouX3oYbSb0Nr2BZzNcUx0pS2BFEUxfpWrqiogJ2dHb744guMHz9eKo+KikJxcTG++uqrOtu4ffs2OnXqhM8//xxjxozBr7/+Cn9/fxw5cgTBwcFSvSFDhiA4OBjvvPNOtTYSEhKQmJhYrTwlJQV2dnb13RwiIiKyoPLycjz33HO4desWnJycaq1r0p8I169fR1VVFdzc3AzK3dzccObMmXq18frrr8PT0xPh4eEAgPz8fKmNB9vUz3tQXFwcYmNjpemSkhLpUlNdG1wbrVaL9PR0REREQKVSmd1OS9eQflh/fH2t8w9dvNmQ0MwS5PiUycsIYhV87l5ArtofoqBsgqjk6ejtbdJnKygwxqondlaeQiV0FozK8DucOaxbs63X4HfhwHsG89bfOtFscdSq64AmX4WgE+Bx1QN5XnkQFfX+Gxf4bT8A4EXnXk0UWfPRigqkl3ZHhMM5qATL/j5YirbfPxr9GKm/QlIfzXpOc9myZdi6dSsyMzOhVqvNbsfGxgY2NjbVylUqVaN0YmO109KZ0w91/WdmiQNfQxIOUVC2qYTF2PdTCZ3FE5b7vwNL/G6qVKpqBykRlc0eh1GmJBANJCpE0xKWP/qoNR3gVYKuVW2PSf743WvMY6Qp7Zh0Mc7V1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmXYs2cPAgMDpXL9cua0SURERG2DSQmLtbU1QkJCkJGRIZXpdDpkZGSgf//+NS63YsUKLF68GGlpaQgNDTWY5+vrC3d3d4M2S0pKcPDgwVrbJCIiorbD5EtCsbGxiIqKQmhoKPr27Yvk5GSUlZVJdw1NnToVXl5eSEpKAgAsX74cCxcuREpKCnx8fKRxKQ4ODnBwcIAgCJg9ezaWLFmCgIAA+Pr6YsGCBfD09DQY2EtERERtl8kJy8SJE1FUVISFCxciPz8fwcHBSEtLkwbNXrp0CQrFnyduPvjgA1RUVOCZZ54xaCc+Ph4JCQkAgLlz56KsrAwvvfQSiouLMXDgQKSlpTVonAsRERG1HmYNuo2JiUFMTIzReZmZmQbTubm5dbYnCAIWLVqERYsWmRMOERERtXJt9wk4RERE1GIwYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnllvayYiak6HS1Klz2tzOjTbegWdAE94Yv3x9RCLjzXbeomoOp5hISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPd4lRETUAJeL71h0/Vcu3GjydVhBgfFWnjh08SYqoav3cp1L/uibdk0TF7UtPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGTPrIRlzZo18PHxgVqtRlhYGA4dOlRj3ZMnT+Lpp5+Gj48PBEFAcnJytToJCQkQBMHgp0ePHuaERkRERK2QyQlLamoqYmNjER8fj8OHDyMoKAiRkZEoLCw0Wr+8vBx+fn5YtmwZ3N3da2z3kUceQV5envTzww8/mBoaERERtVJWpi6wevVqTJ8+HdHR0QCAdevW4euvv8aGDRswb968avUfe+wxPPbYYwBgdL4UiJVVrQlNS/B2+jlLh1DNqxHdLR0CERFRg5mUsFRUVCA7OxtxcXFSmUKhQHh4OLKyshoUyPnz5+Hp6Qm1Wo3+/fsjKSkJXbp0MVpXo9FAo9FI0yUlJQAArVYLrVZrdgz6Zc1tQxCrzF53UzFnWxrSD4JOqHW+lQWGTZnzveiXkeN32pTu/370ny3xndWmrn2sKdZ171/j/10qoWq2eIxpju/H3H1B3zdaUV77kDn029AatsVcDT1G1tZmfQiiKIr1rXzt2jV4eXlh//796N+/v1Q+d+5cfPfddzh48GCty/v4+GD27NmYPXu2Qfn//vc/lJaW4qGHHkJeXh4SExNx9epVnDhxAo6OjtXaSUhIQGJiYrXylJQU2NnZ1XdziIiIyILKy8vx3HPP4datW3Bycqq1rsmXhJrCqFGjpM+BgYEICwtD165d8dlnn+GFF16oVj8uLg6xsbHSdElJCby9vTFixIg6N7g2Wq0W6enpiIiIgEpl+l9Na/b+Yva6m8rMYd1MXqYh/bD++Ppa5x+6eNPkeBoqyPEpk5cRxCr43L2AXLU/REHZBFHJ09Hb26TPVlBgjFVP7Kw8hUroLBiVob6+Ls22LkEnwOOqB/K88iBe/tFonavFd5stHqPrdwpu8nWYuy94leQAAOK7hjZRZM1HKyqQXtodEQ7noBLk8/vQnLT9/tGgY6Qx+isk9WFSwuLq6gqlUomCggKD8oKCgkYdf9KuXTt0794dv/xiPAGwsbGBjY1NtXKVStUonWhuO3I8sDWkP8zpB1FR+wk7Sxz4GvK9iIJSlt9rUzH2/VRCJ6uEpa59rKnWKaLS6LwqNN7pcXM053dj6r6g75vWdIBXCbpWtT0m+eN40FjHWn1b9WXSxThra2uEhIQgIyNDKtPpdMjIyDC4RNRQpaWluHDhAjw8PBqtTSIiImq5TL4kFBsbi6ioKISGhqJv375ITk5GWVmZdNfQ1KlT4eXlhaSkJAD3BuqeOnVK+nz16lXk5OTAwcEB3brdu1zx2muvYezYsejatSuuXbuG+Ph4KJVKTJ48ubG2k4iIiFowkxOWiRMnoqioCAsXLkR+fj6Cg4ORlpYGNzc3AMClS5egUPx54ubatWt49NFHpemVK1di5cqVGDJkCDIzMwEAV65cweTJk3Hjxg107NgRAwcOxIEDB9CxY8cGbh4RERG1BmYNuo2JiUFMTIzRefokRM/Hxwd13Yi0detWc8IgIiKiNkIWdwkRmatzSXaddfoV3zK5XZ2gxPUOA/DYlU1QtMJnsRzo8pKlQyAiMknbfQIOERERtRhMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcken8NCrd7/pzD9LdpKqPAYBuB/wq+oEprm5XZP6kx/kzYRUVvFMyxEREQke0xYiIiISPaYsBAREZHscQwLERE1qaxfb1g6BAP9/TpYOgQyA8+wEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkexZWToAorbq/1P8YrF1XylJtdi6qe2x5L5uzJHiq5jRLtDSYZCJeIaFiIiIZI8JCxEREckeExYiIiKSPY5hIaIWJevCjWZblxUUGG/liUMXb8K95E6zrZeIquMZFiIiIpI9JixEREQke0xYiIiISPbMSljWrFkDHx8fqNVqhIWF4dChQzXWPXnyJJ5++mn4+PhAEAQkJyc3uE0iIiJqW0xOWFJTUxEbG4v4+HgcPnwYQUFBiIyMRGFhodH65eXl8PPzw7Jly+Du7t4obRIREVHbYnLCsnr1akyfPh3R0dHo2bMn1q1bBzs7O2zYsMFo/cceewxvvfUWJk2aBBsbm0Zpk4iIiNoWk25rrqioQHZ2NuLi4qQyhUKB8PBwZGVlmRWAOW1qNBpoNBppuqSkBACg1Wqh1WrNikO//P3/mkoQq8xed1MxZ1sa0g+CTqh1vlUjD5tSQtWo7ekp/mhX0UTtW1p9vgd9ncb+zlqS+/ugqfa1hmqO78fcfUGufSbAClrRtG3R1zd1udakocfI2tqsD5MSluvXr6Oqqgpubm4G5W5ubjhz5owpTTWozaSkJCQmJlYr37NnD+zs7MyK437p6elmLefb4DU3vl27zpm9rDn94AnPWuePt6p9vslcejVuew8IcZnYpO1bymMm1B1j1bPJ4mgpxlj1BFzk2Q+mfJcNZfK+0MS/nw2x67Z5y6WXdm/cQFqSP44J5h4jjSkvL6933Rb54Li4uDjExsZK0yUlJfD29saIESPg5ORkdrtarRbp6emIiIiASmX6XwZr9srrBV8AMHNYN5OXaUg/rD++vtb5hy7eNDme2niV5DRqe3oKqBDiMhHZN1OhQ+P9NSEXV52C66xjBQXGWPXEzspTqISu6YOSofv7wK3ksKXDMao+32VDmbsvNNXvZ0N5tVPjRWfTkimtqEB6aXdEOJyDSmibvw/afv9o0DHSGP0VkvowKWFxdXWFUqlEQUGBQXlBQUGNA2qbok0bGxuj42FUKlWjdKK57YiCssHrbmwN6Q9z+kFUiLXOb+wDX1UTJxM6aJt8HZZgyvdQCV2bTVj0KqGT7X7QnN+NqfuCXPtMhJXZSYdK0LXZhAV/HA8a61irb6u+TLoYZ21tjZCQEGRkZEhlOp0OGRkZ6N+/vylNNWmbRERE1LqYfEkoNjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDeo9tSpU9Lnq1evIicnBw4ODujWrVu92iQiIqK2zeSEZeLEiSgqKsLChQuRn5+P4OBgpKWlSYNmL126BIXizxM3165dw6OPPipNr1y5EitXrsSQIUOQmZlZrzaJiIiobTNr0G1MTAxiYmKMztMnIXo+Pj4QxdrHNdTVJhEREbVtbfeGciIiImoxmLAQERGR7DFhISIiItlrkQ+Oa25rc9bWq97hkhtNHInp1uZ0MHkZQSfAE55Yf3x9nc9VISIiag5MWIhINjqXZFs6BANKqACXXvAqyYH83hR2T3P0mWE/yPNhcNT68ZIQERERyR7PsLRyWRdMv0xlBQXGW3ni0MWbbf6R7EREJA88w0JERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2rCwdABE1v84l2XXWUUIFuPSCV0kOqqBthqiIiGrGMyxEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9sxKWNWvWwMfHB2q1GmFhYTh06FCt9T///HP06NEDarUavXv3xq5duwzmT5s2DYIgGPyMHDnSnNCIiIioFbIydYHU1FTExsZi3bp1CAsLQ3JyMiIjI3H27Fl06tSpWv39+/dj8uTJSEpKwpgxY5CSkoLx48fj8OHD6NWrl1Rv5MiR2LhxozRtY2Nj5iZRU+hckm3pEIiIqA0z+QzL6tWrMX36dERHR6Nnz55Yt24d7OzssGHDBqP133nnHYwcORJz5szBww8/jMWLF6NPnz54//33DerZ2NjA3d1d+mnfvr15W0REREStjklnWCoqKpCdnY24uDipTKFQIDw8HFlZWUaXycrKQmxsrEFZZGQkduzYYVCWmZmJTp06oX379hg+fDiWLFmCDh06GG1To9FAo9FI0yUlJQAArVYLrVZryiYZ0C/7YBuCTqjX8latZEiQfjvu3x4lVJYKxyIUf2yvoo1t9/3YB+wDvdbWDwKsoBVN+/9aX9/U5VqTmo6RjdFmfZiUsFy/fh1VVVVwc3MzKHdzc8OZM2eMLpOfn2+0fn5+vjQ9cuRIPPXUU/D19cWFCxcwf/58jBo1CllZWVAqldXaTEpKQmJiYrXyPXv2wM7OzpRNMio9Pd1g2hOe9VpuvFX96rUUY6x6/jnh0qvmiq1YiMtES4dgcewD9oFea+qHXbfNWy69tHvjBtKS/HFsfPAY2RDl5eX1rmvyGJamMGnSJOlz7969ERgYCH9/f2RmZuKJJ56oVj8uLs7grE1JSQm8vb0xYsQIODk5mR2HVqtFeno6IiIioFL9+ZfE+uPr67X8oYs3zV63nFhBgTFWPbGz8hQqoQMAeJXkWDaoZqaACiEuE5F9MxU6NN5fEy0J+4B9oNfa+sGrnRovOpv2R5hWVCC9tDsiHM5BJeiaKDJ50/b7h9FjZEPor5DUh0kJi6urK5RKJQoKCgzKCwoK4O7ubnQZd3d3k+oDgJ+fH1xdXfHLL78YTVhsbGyMDspVqVSN0okPtiMqxHotpz+4txaV0EnbVNUK/pMyhw7aNrvteuwD9oFea+kHEVZmJx0qQddmExb8cVxsrGOtvq36MulinLW1NUJCQpCRkSGV6XQ6ZGRkoH///kaX6d+/v0F94N7ppJrqA8CVK1dw48YNeHh4mBIeERERtVImjx6KjY3FRx99hM2bN+P06dN45ZVXUFZWhujoaADA1KlTDQblzpo1C2lpaVi1ahXOnDmDhIQE/Pzzz4iJiQEAlJaWYs6cOThw4AByc3ORkZGBcePGoVu3boiMjGykzSQiIqKWzOQxLBMnTkRRUREWLlyI/Px8BAcHIy0tTRpYe+nSJSgUf+ZBAwYMQEpKCt58803Mnz8fAQEB2LFjh/QMFqVSiWPHjmHz5s0oLi6Gp6cnRowYgcWLF/NZLERERATAzEG3MTEx0hmSB2VmZlYre/bZZ/Hss88arW9ra4vdu3ebEwYRERG1EW33hnIiIiJqMZiwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnlkJy5o1a+Dj4wO1Wo2wsDAcOnSo1vqff/45evToAbVajd69e2PXrl0G80VRxMKFC+Hh4QFbW1uEh4fj/Pnz5oRGRERErZDJCUtqaipiY2MRHx+Pw4cPIygoCJGRkSgsLDRaf//+/Zg8eTJeeOEFHDlyBOPHj8f48eNx4sQJqc6KFSvw7rvvYt26dTh48CDs7e0RGRmJu3fvmr9lRERE1GqYnLCsXr0a06dPR3R0NHr27Il169bBzs4OGzZsMFr/nXfewciRIzFnzhw8/PDDWLx4Mfr06YP3338fwL2zK8nJyXjzzTcxbtw4BAYG4uOPP8a1a9ewY8eOBm0cERERtQ5WplSuqKhAdnY24uLipDKFQoHw8HBkZWUZXSYrKwuxsbEGZZGRkVIycvHiReTn5yM8PFya7+zsjLCwMGRlZWHSpEnV2tRoNNBoNNL0rVu3AAA3b96EVqs1ZZMMaLValJeX48aNG1CpVFL53ZL6nenRlVeYvW450UGBcqty6CoroIMOAFB5x8JBNTMdgPLycmjv4I8eaHvYB+wDvdbWD3etdbhhZdr/11pRce/4IFRAJbSGXjCd9sYNo8fIhrh9+zaAeycv6mJSwnL9+nVUVVXBzc3NoNzNzQ1nzpwxukx+fr7R+vn5+dJ8fVlNdR6UlJSExMTEauW+vr712xCq08eWDkAWvrB0ADLAPmAf6LWufviXpQNokRKarOXbt2/D2dm51jomJSxyERcXZ3DWRqfT4ebNm+jQoQMEQTC73ZKSEnh7e+Py5ctwcnJqjFBbJPYD+wBgHwDsAz32A/sAaJo+EEURt2/fhqenZ511TUpYXF1doVQqUVBQYFBeUFAAd3d3o8u4u7vXWl//b0FBATw8PAzqBAcHG23TxsYGNjY2BmXt2rUzZVNq5eTk1GZ3yPuxH9gHAPsAYB/osR/YB0Dj90FdZ1b0TBp0a21tjZCQEGRkZEhlOp0OGRkZ6N+/v9Fl+vfvb1AfANLT06X6vr6+cHd3N6hTUlKCgwcP1tgmERERtS0mXxKKjY1FVFQUQkND0bdvXyQnJ6OsrAzR0dEAgKlTp8LLywtJSUkAgFmzZmHIkCFYtWoVRo8eja1bt+Lnn3/Gv//9bwCAIAiYPXs2lixZgoCAAPj6+mLBggXw9PTE+PHjG29LiYiIqMUyOWGZOHEiioqKsHDhQuTn5yM4OBhpaWnSoNlLly5BofjzxM2AAQOQkpKCN998E/Pnz0dAQAB27NiBXr16SXXmzp2LsrIyvPTSSyguLsbAgQORlpYGtVrdCJtYfzY2NoiPj692uamtYT+wDwD2AcA+0GM/sA8Ay/eBINbnXiIiIiIiC+K7hIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JSy2efPJJdOnSBWq1Gh4eHvjb3/6Ga9euWTqsZpObm4sXXngBvr6+sLW1hb+/P+Lj41FR0Tpe8lhfS5cuxYABA2BnZ9eoT1SWuzVr1sDHxwdqtRphYWE4dOiQpUNqVt9//z3Gjh0LT09PCILQ5t4en5SUhMceewyOjo7o1KkTxo8fj7Nnz1o6rGb3wQcfIDAwUHq6a//+/fG///3P0mFZ1LJly6RnqDUnJiy1GDZsGD777DOcPXsWX375JS5cuIBnnnnG0mE1mzNnzkCn0+HDDz/EyZMn8fbbb2PdunWYP3++pUNrVhUVFXj22WfxyiuvWDqUZpOamorY2FjEx8fj8OHDCAoKQmRkJAoLCy0dWrMpKytDUFAQ1qxZY+lQLOK7777DzJkzceDAAaSnp0Or1WLEiBEoKyuzdGjNqnPnzli2bBmys7Px888/Y/jw4Rg3bhxOnjxp6dAs4qeffsKHH36IwMDA5l+5SPX21VdfiYIgiBUVFZYOxWJWrFgh+vr6WjoMi9i4caPo7Oxs6TCaRd++fcWZM2dK01VVVaKnp6eYlJRkwagsB4C4fft2S4dhUYWFhSIA8bvvvrN0KBbXvn17cf369ZYOo9ndvn1bDAgIENPT08UhQ4aIs2bNatb18wxLPd28eRNbtmzBgAEDoFKpLB2Oxdy6dQsuLi6WDoOaUEVFBbKzsxEeHi6VKRQKhIeHIysry4KRkSXdunULANr0739VVRW2bt2KsrKyNvmuu5kzZ2L06NEG/zc0JyYsdXj99ddhb2+PDh064NKlS/jqq68sHZLF/PLLL3jvvffw97//3dKhUBO6fv06qqqqpNdt6Lm5uSE/P99CUZEl6XQ6zJ49G48//rjBa1XaiuPHj8PBwQE2NjZ4+eWXsX37dvTs2dPSYTWrrVu34vDhw9J7Ai2hzSUs8+bNgyAItf6cOXNGqj9nzhwcOXIEe/bsgVKpxNSpUyG28LcZmNoHAHD16lWMHDkSzz77LKZPn26hyBuPOX1A1FbNnDkTJ06cwNatWy0dikU89NBDyMnJwcGDB/HKK68gKioKp06dsnRYzeby5cuYNWsWtmzZ0uzv+Ltfm3uXUFFREW7cuFFrHT8/P1hbW1crv3LlCry9vbF///4WfTrQ1D64du0ahg4din79+mHTpk0GL7dsqczZDzZt2oTZs2ejuLi4iaOzrIqKCtjZ2eGLL74weGN6VFQUiouL2+RZRkEQsH379jb5BvmYmBh89dVX+P777+Hr62vpcGQhPDwc/v7++PDDDy0dSrPYsWMH/vrXv0KpVEplVVVVEAQBCoUCGo3GYF5TMfltzS1dx44d0bFjR7OW1el0AACNRtOYITU7U/rg6tWrGDZsGEJCQrBx48ZWkawADdsPWjtra2uEhIQgIyNDOkDrdDpkZGQgJibGssFRsxFFEf/4xz+wfft2ZGZmMlm5j06na/HHAVM88cQTOH78uEFZdHQ0evTogddff71ZkhWgDSYs9XXw4EH89NNPGDhwINq3b48LFy5gwYIF8Pf3b9FnV0xx9epVDB06FF27dsXKlStRVFQkzXN3d7dgZM3r0qVLuHnzJi5duoSqqirk5OQAALp16wYHBwfLBtdEYmNjERUVhdDQUPTt2xfJyckoKytDdHS0pUNrNqWlpfjll1+k6YsXLyInJwcuLi7o0qWLBSNrHjNnzkRKSgq++uorODo6SuOXnJ2dYWtra+Homk9cXBxGjRqFLl264Pbt20hJSUFmZiZ2795t6dCajaOjY7WxS/qxnc06pqlZ70lqQY4dOyYOGzZMdHFxEW1sbEQfHx/x5ZdfFq9cuWLp0JrNxo0bRQBGf9qSqKgoo32wd+9eS4fWpN577z2xS5cuorW1tdi3b1/xwIEDlg6pWe3du9fo9x4VFWXp0JpFTb/7GzdutHRozer5558Xu3btKlpbW4sdO3YUn3jiCXHPnj2WDsviLHFbc5sbw0JEREQtT+sYkEBEREStGhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7/z9HhY5nYwKkDgAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "for ctrl_s in ctrl_str:\n", - " plt.hist(\n", - " [r for r, t in zip(logs[\"env/reward_dist\"], task_list) if t == ctrl_s],\n", - " density=True,\n", - " alpha=0.5,\n", - " label=ctrl_s,\n", - " )\n", - "plt.legend(loc=\"best\")\n", - "plt.title(\"reward distribution\")\n", - "plt.grid(True)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save model\n", - "Finally, we save the model to disk for later usage." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gpt2_model.save_pretrained(\"gpt2-imdb-ctrl\")\n", - "gpt2_tokenizer.save_pretrained(\"gpt2-imdb-ctrl\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "trl", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - }, - "vscode": { - "interpreter": { - "hash": "d2cfb53525227c89f8d14fa784301fa46c451cc9223d94ccce9e17956835eea2" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/notebooks/gpt2-sentiment.ipynb b/examples/notebooks/gpt2-sentiment.ipynb deleted file mode 100644 index 3c9b4a62385..00000000000 --- a/examples/notebooks/gpt2-sentiment.ipynb +++ /dev/null @@ -1,867 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tune GPT2 to generate positive reviews\n", - "> Optimise GPT2 to produce positive IMDB movie reviews using a BERT sentiment classifier as a reward function." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "
\n", - "\n", - "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", - "
\n", - "\n", - "\n", - "In this notebook we fine-tune GPT2 (small) to generate positive movie reviews based on the IMDB dataset. The model gets the start of a real review and is tasked to produce positive continuations. To reward positive continuations we use a BERT classifier to analyse the sentiment of the produced sentences and use the classifier's outputs as rewards signals for PPO training." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup experiment" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Import dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install transformers trl wandb" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from tqdm import tqdm\n", - "import pandas as pd\n", - "\n", - "tqdm.pandas()\n", - "\n", - "from transformers import pipeline, AutoTokenizer\n", - "from datasets import load_dataset\n", - "\n", - "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead\n", - "from trl.core import LengthSampler" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = PPOConfig(\n", - " model_name=\"lvwerra/gpt2-imdb\",\n", - " learning_rate=1.41e-5,\n", - " log_with=\"wandb\",\n", - ")\n", - "\n", - "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import wandb\n", - "\n", - "wandb.init()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/main/examples/legacy/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", - "https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load data and models" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load IMDB dataset\n", - "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the `LengthSampler`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def build_dataset(\n", - " config,\n", - " dataset_name=\"stanfordnlp/imdb\",\n", - " input_min_text_length=2,\n", - " input_max_text_length=8,\n", - "):\n", - " \"\"\"\n", - " Build dataset for training. This builds the dataset from `load_dataset`, one should\n", - " customize this function to train the model on its own dataset.\n", - "\n", - " Args:\n", - " config (`PPOConfig`):\n", - " The configuration of the PPO training.\n", - " dataset_name (`str`):\n", - " The name of the dataset to be loaded.\n", - " input_min_text_length (`int`, defaults to 5):\n", - " The minimum length of the input text.\n", - " input_max_text_length (`int`, defaults to 10):\n", - " The maximum length of the input text.\n", - "\n", - " Returns:\n", - " dataloader (`torch.utils.data.DataLoader`):\n", - " The dataloader for the dataset.\n", - " \"\"\"\n", - " tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", - " tokenizer.pad_token = tokenizer.eos_token\n", - " # load imdb with datasets\n", - " ds = load_dataset(dataset_name, split=\"train\")\n", - " ds = ds.rename_columns({\"text\": \"review\"})\n", - " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", - "\n", - " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", - "\n", - " def tokenize(sample):\n", - " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", - " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", - " return sample\n", - "\n", - " ds = ds.map(tokenize, batched=False)\n", - " ds.set_format(type=\"torch\")\n", - " return ds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = build_dataset(config)\n", - "\n", - "\n", - "def collator(data):\n", - " return dict((key, [d[key] for d in data]) for key in data[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load pre-trained GPT2 language models" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", - "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", - "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", - "\n", - "tokenizer.pad_token = tokenizer.eos_token" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initialize PPOTrainer\n", - "The `PPOTrainer` takes care of device placement and optimization later on:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ppo_trainer = PPOTrainer(\n", - " config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load BERT classifier\n", - "We load a BERT classifier fine-tuned on the IMDB dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = ppo_trainer.accelerator.device\n", - "if ppo_trainer.accelerator.num_processes == 1:\n", - " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", - "sentiment_pipe = pipeline(\n", - " \"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\", device=device\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n", - " {'label': 'POSITIVE', 'score': -2.726576328277588}]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text = \"this movie was really bad!!\"\n", - "sentiment_pipe(text, **sent_kwargs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'label': 'POSITIVE', 'score': 2.557040214538574},\n", - " {'label': 'NEGATIVE', 'score': -2.294790267944336}]" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text = \"this movie was really good!!\"\n", - "sentiment_pipe(text, **sent_kwargs)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Generation settings\n", - "For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gen_kwargs = {\n", - " \"min_length\": -1,\n", - " \"top_k\": 0.0,\n", - " \"top_p\": 1.0,\n", - " \"do_sample\": True,\n", - " \"pad_token_id\": tokenizer.eos_token_id,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optimize model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training loop" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The training loop consists of the following main steps:\n", - "1. Get the query responses from the policy network (GPT-2)\n", - "2. Get sentiments for query/responses from BERT\n", - "3. Optimize policy with PPO using the (query, response, reward) triplet\n", - "\n", - "**Training time**\n", - "\n", - "This step takes **~2h** on a V100 GPU with the above specified settings." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output_min_length = 4\n", - "output_max_length = 16\n", - "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", - "\n", - "\n", - "generation_kwargs = {\n", - " \"min_length\": -1,\n", - " \"top_k\": 0.0,\n", - " \"top_p\": 1.0,\n", - " \"do_sample\": True,\n", - " \"pad_token_id\": tokenizer.eos_token_id,\n", - "}\n", - "\n", - "\n", - "for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):\n", - " query_tensors = batch[\"input_ids\"]\n", - "\n", - " #### Get response from gpt2\n", - " response_tensors = []\n", - " for query in query_tensors:\n", - " gen_len = output_length_sampler()\n", - " generation_kwargs[\"max_new_tokens\"] = gen_len\n", - " query_response = ppo_trainer.generate(query, **generation_kwargs).squeeze()\n", - " response_len = len(query_response) - len(query)\n", - " response_tensors.append(query_response[-response_len:])\n", - " batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n", - "\n", - " #### Compute sentiment score\n", - " texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n", - " pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", - " positive_scores = [\n", - " item[\"score\"]\n", - " for output in pipe_outputs\n", - " for item in output\n", - " if item[\"label\"] == \"POSITIVE\"\n", - " ]\n", - " rewards = [torch.tensor(score) for score in positive_scores]\n", - "\n", - " #### Run PPO step\n", - " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", - " ppo_trainer.log_stats(stats, batch, rewards)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training progress\n", - "If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://wandb.ai/huggingface/trl/runs/w9l3110g).\n", - "\n", - "
\n", - "\n", - "

Figure: Reward mean and distribution evolution during training.

\n", - "
\n", - "\n", - "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", - "\n", - "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model inspection\n", - "Let's inspect some examples from the IMDB dataset. We can use `ref_model` to compare the tuned model `model` against the model before optimisation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
queryresponse (before)response (after)rewards (before)rewards (after)
0I rented Zero Day4 for my sister. To my surprise, the Wii caug.... It is a pleasure. It is a huge leap 68 years...1.7360682.423731
1The onlydistro of herspecial compliments is the0.1508520.190159
2I've read a fewnews reports about Mr. Mueller's activities b...novels and I never watch this. It has a reall...-1.4179622.831814
3This is the second British Rank film, and I wouldn't be surprised anymore if itthat I have enjoyed, achieving it in both the0.8358762.205628
4A classicclassic.<br /><br />And only this one will ha.... It's a movie with a fine cast. As the beginn...2.1130752.739168
5This has to be one of theworst with the differences being that for thebest thriller films I've seen in recent-2.7053392.730615
6Happy Go Lovely is a waste. Not only are extremelyof time, giving a-2.429504-2.934672
7Wow, I justcan't make fun of itfeek it! This show-2.201666-0.106085
8This movie makes several mistakes.Despite being a great comedic diversion it es...It's cool, wonderful - it held me into a very ...-1.2323802.707638
9Branagh and Fishburne, Drake is playedis a great show. Beautiful0.7768192.808996
10I might have given this movie arating of *11 when I heard that!), but it was...great performance. It was truly a great movie...0.2763802.743328
11Really, really badwith feel like there is no end to the. This movie is incredibly good, with the-2.639503-1.568827
12What another reviewer called lack ofjudgment, connecting into her own harsh obser...suspense. Rogers and Rooney rate this as exce...-1.0797072.696888
13This is simply onemore problem of Steveof the best choice-1.4454362.662699
14\"Perhaps we can arrange a meet-and-greet.<br /><br />Telegwith spent, classic music and dance, and come...0.2584791.876662
15Richard Willaims isnice enough; the little black guy plays quitebeautifully hands on in his own spin, and0.7965082.820259
\n", - "
" - ], - "text/plain": [ - " query \\\n", - "0 I rented Zero Day \n", - "1 The only \n", - "2 I've read a few \n", - "3 This is the second British Rank film \n", - "4 A classic \n", - "5 This has to be one of the \n", - "6 Happy Go Lovely is a waste \n", - "7 Wow, I just \n", - "8 This movie makes several mistakes. \n", - "9 Branagh and Fish \n", - "10 I might have given this movie a \n", - "11 Really, really bad \n", - "12 What another reviewer called lack of \n", - "13 This is simply one \n", - "14 \"Perhaps we can arrange a meet \n", - "15 Richard Willaims is \n", - "\n", - " response (before) \\\n", - "0 4 for my sister. To my surprise, the Wii caug... \n", - "1 distro of her \n", - "2 news reports about Mr. Mueller's activities b... \n", - "3 , and I wouldn't be surprised anymore if it \n", - "4 classic.

And only this one will ha... \n", - "5 worst with the differences being that for the \n", - "6 . Not only are extremely \n", - "7 can't make fun of it \n", - "8 Despite being a great comedic diversion it es... \n", - "9 burne, Drake is played \n", - "10 rating of *11 when I heard that!), but it was... \n", - "11 with feel like there is no end to the \n", - "12 judgment, connecting into her own harsh obser... \n", - "13 more problem of Steve \n", - "14 -and-greet.

Teleg \n", - "15 nice enough; the little black guy plays quite \n", - "\n", - " response (after) rewards (before) \\\n", - "0 . It is a pleasure. It is a huge leap 68 years... 1.736068 \n", - "1 special compliments is the 0.150852 \n", - "2 novels and I never watch this. It has a reall... -1.417962 \n", - "3 that I have enjoyed, achieving it in both the 0.835876 \n", - "4 . It's a movie with a fine cast. As the beginn... 2.113075 \n", - "5 best thriller films I've seen in recent -2.705339 \n", - "6 of time, giving a -2.429504 \n", - "7 feek it! This show -2.201666 \n", - "8 It's cool, wonderful - it held me into a very ... -1.232380 \n", - "9 is a great show. Beautiful 0.776819 \n", - "10 great performance. It was truly a great movie... 0.276380 \n", - "11 . This movie is incredibly good, with the -2.639503 \n", - "12 suspense. Rogers and Rooney rate this as exce... -1.079707 \n", - "13 of the best choice -1.445436 \n", - "14 with spent, classic music and dance, and come... 0.258479 \n", - "15 beautifully hands on in his own spin, and 0.796508 \n", - "\n", - " rewards (after) \n", - "0 2.423731 \n", - "1 0.190159 \n", - "2 2.831814 \n", - "3 2.205628 \n", - "4 2.739168 \n", - "5 2.730615 \n", - "6 -2.934672 \n", - "7 -0.106085 \n", - "8 2.707638 \n", - "9 2.808996 \n", - "10 2.743328 \n", - "11 -1.568827 \n", - "12 2.696888 \n", - "13 2.662699 \n", - "14 1.876662 \n", - "15 2.820259 " - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#### get a batch from the dataset\n", - "bs = 16\n", - "game_data = dict()\n", - "dataset.set_format(\"pandas\")\n", - "df_batch = dataset[:].sample(bs)\n", - "game_data[\"query\"] = df_batch[\"query\"].tolist()\n", - "query_tensors = df_batch[\"input_ids\"].tolist()\n", - "\n", - "response_tensors_ref, response_tensors = [], []\n", - "\n", - "#### get response from gpt2 and gpt2_ref\n", - "for i in range(bs):\n", - " query = torch.tensor(query_tensors[i]).to(device)\n", - "\n", - " gen_len = output_length_sampler()\n", - " query_response = ref_model.generate(\n", - " query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs\n", - " ).squeeze()\n", - " response_len = len(query_response) - len(query)\n", - " response_tensors_ref.append(query_response[-response_len:])\n", - "\n", - " query_response = model.generate(\n", - " query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs\n", - " ).squeeze()\n", - " response_len = len(query_response) - len(query)\n", - " response_tensors.append(query_response[-response_len:])\n", - "\n", - "#### decode responses\n", - "game_data[\"response (before)\"] = [\n", - " tokenizer.decode(response_tensors_ref[i]) for i in range(bs)\n", - "]\n", - "game_data[\"response (after)\"] = [\n", - " tokenizer.decode(response_tensors[i]) for i in range(bs)\n", - "]\n", - "\n", - "#### sentiment analysis of query/response pairs before/after\n", - "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n", - "pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", - "positive_scores = [\n", - " item[\"score\"]\n", - " for output in pipe_outputs\n", - " for item in output\n", - " if item[\"label\"] == \"POSITIVE\"\n", - "]\n", - "game_data[\"rewards (before)\"] = positive_scores\n", - "\n", - "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n", - "pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", - "positive_scores = [\n", - " item[\"score\"]\n", - " for output in pipe_outputs\n", - " for item in output\n", - " if item[\"label\"] == \"POSITIVE\"\n", - "]\n", - "game_data[\"rewards (after)\"] = positive_scores\n", - "\n", - "# store results in a dataframe\n", - "df_results = pd.DataFrame(game_data)\n", - "df_results" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Looking at the reward mean/median of the generated sequences we observe a significant difference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "mean:\n" - ] - }, - { - "data": { - "text/plain": [ - "rewards (before) -0.512965\n", - "rewards (after) 1.676750\n", - "dtype: float64" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "median:\n" - ] - }, - { - "data": { - "text/plain": [ - "rewards (before) -0.464427\n", - "rewards (after) 2.679794\n", - "dtype: float64" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "print(\"mean:\")\n", - "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n", - "print()\n", - "print(\"median:\")\n", - "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Save model\n", - "Finally, we save the model and push it to the Hugging Face for later usage." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('gpt2-imdb-pos-v2/tokenizer_config.json',\n", - " 'gpt2-imdb-pos-v2/special_tokens_map.json',\n", - " 'gpt2-imdb-pos-v2/vocab.json',\n", - " 'gpt2-imdb-pos-v2/merges.txt',\n", - " 'gpt2-imdb-pos-v2/added_tokens.json',\n", - " 'gpt2-imdb-pos-v2/tokenizer.json')" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n", - "tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "env", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - }, - "vscode": { - "interpreter": { - "hash": "4c8ff454cd947027f86954d72bf940c689a97dcc494eb53cfe4813862c6065fe" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}