Merged
33 changes: 0 additions & 33 deletions .github/workflows/main.yml

This file was deleted.

4 changes: 2 additions & 2 deletions README.md
@@ -21,7 +21,7 @@ This process is illustrated in the sketch below:


<div style="text-align: center">
<img src="nbs/images/trl_overview.png" width="800">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>

@@ -94,7 +94,7 @@ train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
For a detailed example check out the notebook `04-gpt2-sentiment-ppo-training.ipynb`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:

<div style="text-align: center">
<img src="nbs/images/table_imdb_preview.png" width="800">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/table_imdb_preview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>
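
The `ppo_trainer.step` call shown in the hunk above is the final line of a short end-to-end loop: encode a query, let the model respond, score the response, and run one PPO update. A minimal sketch of that loop, assuming the trl 0.x API from the same era as this README (the `trl.gpt2`/`trl.ppo` import paths and the `respond_to_batch` helper are not shown in this diff and may differ in other releases):

```python
# Sketch of one PPO update with the old trl 0.x API. Only the calls visible in
# the diff (GPT2HeadWithValueModel.from_pretrained, PPOTrainer(...), and
# ppo_trainer.step(...)) are confirmed; the import paths and respond_to_batch
# are assumptions based on the same-era quickstart.
import torch
from transformers import GPT2Tokenizer
from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer

# active model, frozen reference model (used for the KL penalty), and tokenizer
gpt2_model = GPT2HeadWithValueModel.from_pretrained("gpt2")
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained("gpt2")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

ppo_config = {"batch_size": 1, "forward_batch_size": 1}
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **ppo_config)

# query -> response
query_tensor = gpt2_tokenizer.encode("This morning I went to the ", return_tensors="pt")
response_tensor = respond_to_batch(gpt2_model, query_tensor)

# any scalar reward works here; in the sentiment notebook it is the score of a
# sentiment classifier on the generated review
reward = [torch.tensor(1.0)]

# one PPO optimisation step
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```

In the notebook referenced above, the constant reward is replaced by the sentiment score of a classifier run on the generated movie reviews, which is what pushes GPT2 towards positive continuations.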


Large diffs are not rendered by default.

@@ -13,7 +13,7 @@
"metadata": {},
"source": [
"<div style=\"text-align: center\">\n",
"<img src='images/gpt2_bert_training.png' width='600'>\n",
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_bert_training.png' width='600'>\n",
"<p style=\"text-align: center;\"> <b>Figure:</b> Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face. </p>\n",
"</div>\n",
"\n",
@@ -38,7 +38,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
@@ -48,7 +52,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import torch\n",
@@ -79,7 +87,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"config = {\n",
@@ -115,7 +127,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@@ -141,7 +157,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stderr",
@@ -203,7 +223,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stderr",
@@ -224,7 +248,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
@@ -255,7 +283,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"sent_kwargs = {\n",
@@ -277,7 +309,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
@@ -299,7 +335,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
@@ -342,7 +382,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name'])\n",
@@ -363,7 +407,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
@@ -397,7 +445,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"gpt2_model.to(device);\n",
@@ -421,7 +473,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"class LengthSampler:\n",
@@ -444,7 +500,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def tokenize(sample):\n",
@@ -466,7 +526,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"gen_kwargs = {\n",
@@ -496,7 +560,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def collater(data):\n",
@@ -529,7 +597,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config)\n",
@@ -584,7 +656,7 @@
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/lvwerra/trl-showcase/runs/1jtvxb1m/).\n",
"\n",
"<div style=\"text-align: center\">\n",
"<img src='images/gpt2_tuning_progress.png' width='800'>\n",
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_tuning_progress.png' width='800'>\n",
"<p style=\"text-align: center;\"> <b>Figure:</b> Reward mean and distribution evolution during training. </p>\n",
"</div>\n",
"\n",
@@ -604,7 +676,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stderr",
@@ -901,7 +977,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
@@ -960,7 +1040,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stderr",
@@ -1025,7 +1109,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": []
}