Skip to content

Commit 1849301

Browse files
initial version
0 parents  commit 1849301

File tree

4 files changed

+2708
-0
lines changed

4 files changed

+2708
-0
lines changed

README.md

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Flux training
2+
3+
I made this repo to be able to share with you the details of how I fine-tuned the Flux model. The main steps were made following this [tutorial](https://medium.com/@geronimo7/how-to-train-a-flux1-lora-for-1-dfd1800afce5).
4+
5+
## Hardware requirements
6+
In order to train you would need at least 24 GB VRAM, and currently there is support for single-GPU, so you should go for a SageMaker notebook running on a G5 or higher. For inference, currently you need a bit more, at least 28GB VRAM. I couldn't run inference on our ML-PT account, I ended up renting an instance on Vast.AI with 48GB VRAM. For storage, you would need at least 100GB of storage.
7+
8+
## Training steps
9+
10+
### Step 1: Clone repo and install dependencies
11+
The tutorial I followed is based on the Ostris' [AI-Toolkit](https://github.com/ostris/ai-toolkit). You first begin by cloning that repo and installing it's dependencies. The tutorial uses /workspace as folder, you could choose whatever suits you.
12+
13+
```
14+
!cd /workspace
15+
!git clone https://github.com/ostris/ai-toolkit.git
16+
!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt
17+
```
18+
19+
### Step 2: Upload images and generate captions
20+
Then, you need to upload a folder with your images and captions. If you don't have captions, you can generate them automatically using the script in the `image-caption.ipynb` notebook. The code inside takes the folder with your images, generates a caption for every image and stores it on a .txt file with the same name. Keep in mind that you may need to adapt it to your specific folder structure.
21+
22+
### Step 3: Log into HuggingFace
23+
Access to the FLUX1-dev model is gated, so you first have to accept their terms. Log into your Hugging Face account (or create one) and accept their terms: [FLUX1-dev repository](https://huggingface.co/black-forest-labs/FLUX.1-dev)
24+
25+
Next, generate a Hugging Face API token on your account and log in:
26+
```
27+
!huggingface-cli login --token hf_XXXXTOKENXXXX
28+
```
29+
30+
### Step 4: Define training parameters
31+
32+
On the first cell of the `train-flux.ipynb` you would need to define:
33+
* INPUT_FOLDER: where your images and captions are stored),
34+
* OUTPUT_FOLDER: where to store results like samples and weights,
35+
* TRIGGER_WORD: the name/word for the object or subject you are finetuning on.
36+
37+
and a few other training parameters upon which you can play on, like after how many steps you would like to save the weights or produce sample images to measure progress. For your first time I would leave them as is, and adjust on subsequent training runs if you feel the need for it.
38+
39+
### Step 5: Define job parameter dictionary
40+
41+
That's what the second cell of the notebook is doing. You can dig deeper into each parameter, my advice is that if you are training with limited VRAM, that you leave uncommented the `low_vram` parameter. If you have VRAM to spare, comment it so the training takes much less time.
42+
43+
### Step 6: Final step - run the training job
44+
45+
Finally, run the training job using the last cell of the notebook. The actual training time will vary depending on how many images you used, your GPU VRAM, how many steps you defined, etc. In my case, using 24GB VRAM and `low_vram` mode, it took around 4:30 hs to run 2250 training epochs.
46+
47+
## Inference
48+
49+
Finally, you can use the `flux-lora-img-gen-results.ipynb` notebook to use your fine-tuned model and generate cool images with it. There are several parameters to consider when running the inference:
50+
51+
* prompt and negative prompt: the actual textual description of the image you want to generate. Apparently Flux relies on plain text prompts, different to Stable diffusion which had a bunch of flags and parameters. For inspiration you could visit [PromptHero](https://prompthero.com/flux-prompts?__cf_chl_tk=nKmeQBc9IU6dIH9o44wP3ak3HplrZ71Rfq_jM1gC8k4-1727291842-0.0.1.1-7956), even though I found it to be quite biased towards suggestive images 😒. Also, I left a couple interesting prompts on the inference notebook.
52+
* num_inference_steps: how many inference passes the model does before returning the image. You would have to find a number that gives you the best number, I found that between 20 and 40 I had the best results.
53+
* width
54+
* height
55+
56+
## Extra info
57+
58+
Bojan Jakimovski shared with me this other repository for fine-tuning Flux that looks even easier [FluxGym](https://github.com/cocktailpeanut/fluxgym). I believe is from the guy that made the AI-Browser [Pinokio](https://pinokio.computer/). If anyone tries it out, please let me know and we could add your experience to this repo.

flux-lora-img-gen-results.ipynb

+712
Large diffs are not rendered by default.

image-caption.ipynb

+205
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "14e6b1e1-d567-41c4-98b1-78accfdaf73e",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"!python -m pip install --upgrade pip wheel setuptools"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"id": "2302b9f3-e521-4cbe-99f4-b7d3982b2aff",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"!pip install torch"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "6f522e74-7814-4d70-81da-bd4ae447cf19",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"!FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"id": "38e6031b-adb6-4807-a237-9b679cff6b51",
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"!pip install transformers timm"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": null,
46+
"id": "5de01ad4-42ee-4c82-97db-777260a94163",
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"from transformers import AutoProcessor, AutoModelForCausalLM\n",
51+
"from PIL import Image\n",
52+
"import requests\n",
53+
"import copy\n",
54+
"\n",
55+
"model_id = 'microsoft/Florence-2-large'\n",
56+
"model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval().cuda()\n",
57+
"processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)\n",
58+
"\n",
59+
"def run_example(task_prompt, text_input=None):\n",
60+
" if text_input is None:\n",
61+
" prompt = task_prompt\n",
62+
" else:\n",
63+
" prompt = task_prompt + text_input\n",
64+
"\n",
65+
" inputs = processor(text=prompt, images=image, return_tensors=\"pt\")\n",
66+
" generated_ids = model.generate(\n",
67+
" input_ids=inputs[\"input_ids\"].cuda(),\n",
68+
" pixel_values=inputs[\"pixel_values\"].cuda(),\n",
69+
" max_new_tokens=1024,\n",
70+
" early_stopping=False,\n",
71+
" do_sample=False,\n",
72+
" num_beams=3,\n",
73+
" )\n",
74+
" generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
75+
" parsed_answer = processor.post_process_generation(\n",
76+
" generated_text,\n",
77+
" task=task_prompt,\n",
78+
" image_size=(image.width, image.height)\n",
79+
" )\n",
80+
"\n",
81+
" return parsed_answer"
82+
]
83+
},
84+
{
85+
"cell_type": "code",
86+
"execution_count": 8,
87+
"id": "7f57030b-a1dd-4cca-923a-eab7bbc19d34",
88+
"metadata": {},
89+
"outputs": [
90+
{
91+
"name": "stdout",
92+
"output_type": "stream",
93+
"text": [
94+
"{'<MORE_DETAILED_CAPTION>': 'The image shows a young man standing on a sandy beach with a lake and mountains in the background. He is wearing a grey t-shirt, black shorts, and sunglasses, and has a backpack slung over his shoulder. He has a red hat in his left hand and is holding a pair of sunglasses in his right hand. The man is looking up at the sky with a slight smile on his face. The lake is calm and the water is a light blue color. There are trees and mountains visible in the distance. The sky is clear and blue.'}\n"
95+
]
96+
}
97+
],
98+
"source": [
99+
"image = Image.open(\"img16.jpg\").convert(\"RGB\")\n",
100+
"\n",
101+
"task_prompt = \"<MORE_DETAILED_CAPTION>\"\n",
102+
"answer = run_example(task_prompt=task_prompt)\n",
103+
"\n",
104+
"print(answer)"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 9,
110+
"id": "420c7f51-3150-4b3e-852b-2f6525624d88",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"import os"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"id": "4105d2e1-6ae0-4b66-a393-cc2520903843",
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"os.listdir('./images')"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": 20,
130+
"id": "2239f850-dfb5-46fe-a02d-531a9cb12564",
131+
"metadata": {},
132+
"outputs": [
133+
{
134+
"name": "stdout",
135+
"output_type": "stream",
136+
"text": [
137+
"Captioning image: /images/img9\n",
138+
"Captioning image: /images/img8\n",
139+
"Captioning image: /images/img5\n",
140+
"Captioning image: /images/img4\n",
141+
"Captioning image: /images/img6\n",
142+
"Captioning image: /images/img7\n",
143+
"Captioning image: /images/img3\n",
144+
"Captioning image: /images/img2\n",
145+
"Captioning image: /images/img1\n",
146+
"Captioning image: /images/img16\n",
147+
"Captioning image: /images/img14\n",
148+
"Captioning image: /images/img15\n",
149+
"Captioning image: /images/img11\n",
150+
"Captioning image: /images/img10\n",
151+
"Captioning image: /images/img12\n",
152+
"Captioning image: /images/img13\n"
153+
]
154+
}
155+
],
156+
"source": [
157+
"folder = './images'\n",
158+
"\n",
159+
"list_of_img = os.listdir(folder)\n",
160+
"\n",
161+
"for img in list_of_img:\n",
162+
" if img.endswith('.jpg'):\n",
163+
" file_path = (folder+'/'+img).split('.')[1]\n",
164+
" print(f'Captioning image: {file_path}')\n",
165+
" image_path = '.'+file_path+'.jpg'\n",
166+
" image = Image.open(image_path).convert(\"RGB\")\n",
167+
" task_prompt = \"<MORE_DETAILED_CAPTION>\"\n",
168+
" answer = run_example(task_prompt=task_prompt)\n",
169+
" text_path = '.'+file_path+'.txt'\n",
170+
" with open(text_path, 'w') as f:\n",
171+
" f.write(answer['<MORE_DETAILED_CAPTION>'])\n",
172+
" "
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": null,
178+
"id": "eac60bdd-ab37-4f62-98c8-d068bcd55aad",
179+
"metadata": {},
180+
"outputs": [],
181+
"source": []
182+
}
183+
],
184+
"metadata": {
185+
"kernelspec": {
186+
"display_name": "Python 3 (ipykernel)",
187+
"language": "python",
188+
"name": "python3"
189+
},
190+
"language_info": {
191+
"codemirror_mode": {
192+
"name": "ipython",
193+
"version": 3
194+
},
195+
"file_extension": ".py",
196+
"mimetype": "text/x-python",
197+
"name": "python",
198+
"nbconvert_exporter": "python",
199+
"pygments_lexer": "ipython3",
200+
"version": "3.10.14"
201+
}
202+
},
203+
"nbformat": 4,
204+
"nbformat_minor": 5
205+
}

0 commit comments

Comments
 (0)