# Note: the OpenAI REST endpoint is called directly with requests (see
# generate_image_captions), so the openai SDK itself is never used and is
# not imported. The unused ImageOps import has also been dropped.
import requests
from PIL import Image, ImageFilter
from io import BytesIO
import base64
import gradio as gr
import json
import os
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
API_KEY_FILE = "api_key.json"
# SVG is excluded: Pillow's Image.open cannot rasterize SVG files, so they
# would fail before the format check ever runs.
SUPPORTED_FORMATS = {'png', 'jpeg', 'jpg', 'gif', 'webp', 'bmp', 'tiff', 'ico', 'jfif', 'mpo', 'jpe'}
MAX_TOTAL_SIZE_MB = 20
MAX_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024
PRESETS_FILE = "presets.json"
OUTPUT_PATH_PRESETS_FILE = "output_path_presets.json"
# Token limits for different models
TOKEN_LIMITS = {
    "gpt-3.5-turbo": {"TPM": 60000, "RPM": 500, "RPD": 10000, "TPD": 200000},
    "gpt-4": {"TPM": 10000, "RPM": 500, "RPD": 10000, "TPD": 100000},
    "gpt-4-turbo": {"TPM": 30000, "RPM": 500, "TPD": 90000},
    "gpt-4-vision-preview": {"TPM": 10000, "RPM": 80, "RPD": 500, "TPD": 30000},
    "gpt-4o": {"TPM": 30000, "RPM": 500, "TPD": 90000},
    "gpt-4o-2024-05-13": {"TPM": 30000, "RPM": 500, "TPD": 90000}
}
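# Note: these per-model quotas mirror one account tier's published limits and
# will not match every account; get_token_limit() below treats them as
# fallback defaults rather than authoritative values.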
FAILED_GENERATION_WORDS = ["error", "fail", "exception", "invalid", "unknown",
                           "unrecognized", "unsupported", "unavailable",
                           "unsuccessful", "failure", "incorrect",
                           "incomplete", "sorry", "cannot", "can't", "unable"]
# Load API key from file if available
def load_api_key():
    if os.path.exists(API_KEY_FILE):
        with open(API_KEY_FILE, 'r') as file:
            data = json.load(file)
            return data.get("api_key", "")
    return ""

# Save API key to file
def save_api_key(api_key):
    with open(API_KEY_FILE, 'w') as file:
        json.dump({"api_key": api_key}, file)

# Load presets from file if available
def load_presets(file_path):
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            return json.load(file)
    return []

# Save presets to file
def save_presets(presets, file_path):
    with open(file_path, 'w') as file:
        json.dump(presets, file, indent=4)
def is_supported_format(image_format):
    # image.format can be None (e.g. for images created in memory); treat
    # that as unsupported rather than crashing on None.lower().
    return image_format is not None and image_format.lower() in SUPPORTED_FORMATS

def resize_image(image, max_dimension):
    """Resize the image to fit within the max dimension while maintaining aspect ratio."""
    image.thumbnail((max_dimension, max_dimension), Image.LANCZOS)
    return image

def pad_to_square(image):
    """Pad the image with whitespace to make it a square."""
    max_dimension = max(image.size)
    new_image = Image.new("RGB", (max_dimension, max_dimension), (255, 255, 255))
    new_image.paste(image, ((max_dimension - image.width) // 2, (max_dimension - image.height) // 2))
    return new_image

def encode_image_to_base64(image, image_format):
    buffered = BytesIO()
    image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
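# Example (hypothetical path), matching the data-URL payload built further
# below in generate_image_captions():
#   img = pad_to_square(resize_image(Image.open("sample.png"), 512))
#   b64 = encode_image_to_base64(img, "JPEG")
#   url = f"data:image/jpeg;base64,{b64}"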
def check_total_size(images):
    # Sizes are measured on the base64 strings (roughly 4/3 the raw byte
    # size), which is the payload size the API actually receives.
    total_size = sum(len(img) for img in images)
    if total_size > MAX_SIZE_BYTES:
        raise ValueError(f"Total image size exceeds {MAX_TOTAL_SIZE_MB} MB limit.")

def process_images(images, max_image_dimension):
    base64_images = []
    image_filenames = []
    for image_path in images:
        img = Image.open(image_path)
        if not is_supported_format(img.format):
            raise ValueError(f"Unsupported image format: {img.format}. Supported formats are: {SUPPORTED_FORMATS}.")
        img = resize_image(img, max_image_dimension)
        img = pad_to_square(img)
        # pad_to_square returns a new RGB image whose .format is None, so the
        # original `img.format or 'JPEG'` fallback always produced JPEG anyway.
        # Encode as JPEG explicitly to match the image/jpeg data URL sent below.
        base64_image = encode_image_to_base64(img, 'JPEG')
        base64_images.append(base64_image)
        # Keep the image filename for naming the saved caption file
        filename = os.path.basename(image_path)
        image_filenames.append(filename)
    check_total_size(base64_images)
    return base64_images, image_filenames
def get_token_limit(model):
    """Retrieve the token limit for the specified model."""
    return TOKEN_LIMITS.get(model, {"TPM": 10000, "RPM": 80, "TPD": 30000})
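# e.g. get_token_limit("gpt-4o") -> {"TPM": 30000, "RPM": 500, "TPD": 90000};
# an unrecognized model name falls back to the conservative defaults above.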
def generate_image_captions(api_key, model, images, prompt, max_image_dimension, max_tokens=300):
    base64_images, image_filenames = process_images(images, max_image_dimension)
    # The API key travels in the Authorization header; no SDK setup is needed.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    captions = []
    for img_str in base64_images:
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{img_str}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": max_tokens
        }
        response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
        response_json = response.json()  # Parse the response JSON
        if 'choices' in response_json:
            captions.append(response_json['choices'][0]['message']['content'].strip())
        else:
            error_message = response_json.get('error', {}).get('message', 'Unknown error')
            print(f"Error in API response: {error_message}")
            captions.append(f"Error: {error_message}")
            print(response_json)
    return captions, image_filenames
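# A successful chat-completions response resembles (abridged):
#   {"choices": [{"message": {"role": "assistant", "content": "<caption>"}}],
#    "usage": {"prompt_tokens": ..., "completion_tokens": ...}}
# Any response without a "choices" key is surfaced as an "Error: ..." caption.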
def contains_failed_generation(caption):
    # Crude substring heuristic: flags captions containing any failure word,
    # including inside longer words (e.g. "error" also matches "terror").
    return any(word in caption.lower() for word in FAILED_GENERATION_WORDS)

def save_caption_to_file(caption, filename, output_path):
    if not contains_failed_generation(caption):
        # os.path.splitext strips the extension reliably; the original
        # substring search could match an extension inside the base name
        # (e.g. "gif" in "giftbox.png") and leave the real extension intact.
        filename = os.path.splitext(filename)[0]
        output_file_path = os.path.join(output_path, f"{filename}.txt")
        with open(output_file_path, 'w') as file:
            file.write(caption)
        print(f"Caption saved to: {output_file_path}")
    else:
        print(f"Failed to generate caption for: {filename}")
def batch_generate_captions(api_key, model, images, prompt, output_path, max_image_dimension, folder_path=None):
    if folder_path:
        images = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.split('.')[-1].lower() in SUPPORTED_FORMATS]
    token_limit = get_token_limit(model)
    # Cap tokens per request; note that TPM/RPM quotas are looked up but not
    # actually enforced here (there is no throttling between requests).
    max_tokens_per_request = min(token_limit["TPM"], 300)  # 300 is a default example limit
    # Process images in batches of 20
    batch_size = 20
    batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
    all_captions = []
    failed_images = []
    for batch in batches:
        captions, filenames = generate_image_captions(api_key, model, batch, prompt, max_image_dimension, max_tokens=max_tokens_per_request)
        for caption, filename, image_path in zip(captions, filenames, batch):
            if not contains_failed_generation(caption):
                save_caption_to_file(caption, filename, output_path)
                all_captions.append((caption, filename))
            else:
                failed_images.append(image_path)
    print(f"Failed images: {failed_images}")
    return all_captions, failed_images
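# Sketch of direct (non-UI) use, with hypothetical paths and key:
#   captions, failed = batch_generate_captions(
#       "sk-...", "gpt-4o", ["a.jpg", "b.png"], "Describe this image.",
#       "./captions", max_image_dimension=512)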
def add_to_queue(new_images, queue):
    if queue is None:
        queue = []
    if new_images:
        # gr.Files with type="filepath" yields plain path strings; older Gradio
        # versions yield tempfile wrappers with a .name attribute. Handle both.
        new_image_paths = [getattr(image, "name", image) for image in new_images]
        queue.extend(new_image_paths)
        print(f"Added {len(new_image_paths)} images to queue. Queue length: {len(queue)}")
    return queue

def process_queue(api_key, model, queue, prompt, output_path, max_image_dimension):
    all_captions = []
    failed_images = []
    while queue:
        batch = queue[:20]
        queue = queue[20:]
        print(f"Processing batch of {len(batch)} images. Remaining queue length: {len(queue)}")
        captions, failed_batch = batch_generate_captions(api_key, model, batch, prompt, output_path, max_image_dimension)
        all_captions.extend(captions)
        failed_images.extend(failed_batch)
    print(f"Processing Queue Completed. Failed images: {failed_images}")
    return all_captions, failed_images, queue
def noise_level(image):
    """Estimate the noise level of the image as the grayscale standard deviation."""
    gray_image = image.convert("L")
    noise = np.std(np.array(gray_image))
    return noise

def color_distribution(image):
    """Get color distribution (mean and standard deviation of RGB channels)."""
    # Convert to RGB so grayscale/RGBA inputs also yield 3-channel statistics.
    pixels = np.array(image.convert("RGB"))
    mean = np.mean(pixels, axis=(0, 1))
    std = np.std(pixels, axis=(0, 1))
    return mean, std

def sharpness(image):
    """Estimate sharpness as the mean edge response of the image."""
    image = image.convert("L")
    image = image.filter(ImageFilter.FIND_EDGES)
    sharpness_value = np.mean(np.array(image))
    return sharpness_value

def frequency_analysis(image):
    """Perform frequency analysis using the 2-D Fourier transform."""
    gray_image = image.convert("L")
    f_transform = np.fft.fft2(np.array(gray_image))
    f_shift = np.fft.fftshift(f_transform)
    # The small epsilon avoids log(0) -> -inf for zero-magnitude coefficients.
    magnitude_spectrum = 20 * np.log(np.abs(f_shift) + 1e-8)
    return np.mean(magnitude_spectrum)

def calculate_image_similarity(image1, image2):
    """Calculate similarity between two images using SSIM."""
    image1 = image1.convert("L")
    # SSIM requires arrays of equal shape, so resize the second image to match.
    image2 = image2.convert("L").resize(image1.size)
    image1 = np.array(image1)
    image2 = np.array(image2)
    similarity_index, _ = ssim(image1, image2, full=True)
    return similarity_index

def calculate_caption_similarity(captions):
    """Calculate pairwise caption similarity using TF-IDF and cosine similarity."""
    vectors = TfidfVectorizer().fit_transform(captions).toarray()
    cosine_matrix = cosine_similarity(vectors)
    return cosine_matrix
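# e.g. calculate_caption_similarity(["a red car", "a blue car", "a cat"])
# returns a symmetric 3x3 matrix whose [i][j] entry is near 1.0 when captions
# i and j share most of their TF-IDF weight and near 0.0 when they share none.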
def analyze_directory(directory_path):
    image_stats = []
    image_paths = []
    word_counts = Counter()
    captions = []
    caption_filenames = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            if file.split('.')[-1].lower() in SUPPORTED_FORMATS:
                file_path = os.path.join(root, file)
                img = Image.open(file_path)
                # Compute the color statistics once instead of twice
                color_mean, color_std = color_distribution(img)
                # Image statistics
                img_stats = {
                    "Filename": file,
                    "Width": img.width,
                    "Height": img.height,
                    "Format": img.format,
                    "Size (KB)": os.path.getsize(file_path) / 1024,
                    "Noise Level": noise_level(img),
                    "Color Mean": color_mean,
                    "Color Std": color_std,
                    "Sharpness": sharpness(img),
                    "Frequency Mean": frequency_analysis(img)
                }
                image_stats.append(img_stats)
                # Remember the full path: os.walk recurses into subdirectories,
                # so joining directory_path with the bare filename would break.
                image_paths.append(file_path)
                # splitext handles filenames containing extra dots correctly
                caption_file = os.path.join(root, os.path.splitext(file)[0] + '.txt')
                if os.path.exists(caption_file):
                    with open(caption_file, 'r') as f:
                        caption = f.read()
                        words = caption.split()
                        word_counts.update(words)
                        captions.append(caption)
                        caption_filenames.append(file)
    df_images = pd.DataFrame(image_stats)
    df_words = pd.DataFrame(word_counts.items(), columns=['Word', 'Frequency']).sort_values(by='Frequency', ascending=False)
    # Image similarity
    if len(image_stats) > 1:
        img_similarities = []
        for i in range(len(image_stats)):
            for j in range(i + 1, len(image_stats)):
                img1 = Image.open(image_paths[i])
                img2 = Image.open(image_paths[j])
                similarity = calculate_image_similarity(img1, img2)
                img_similarities.append({
                    "Image1": image_stats[i]["Filename"],
                    "Image2": image_stats[j]["Filename"],
                    "Similarity": similarity
                })
        df_img_similarity = pd.DataFrame(img_similarities)
    else:
        df_img_similarity = pd.DataFrame(columns=["Image1", "Image2", "Similarity"])
    # Caption similarity
    if len(captions) > 1:
        caption_sim_matrix = calculate_caption_similarity(captions)
        caption_similarities = []
        for i in range(len(captions)):
            for j in range(i + 1, len(captions)):
                caption_similarities.append({
                    # Index into caption_filenames, not image_stats: when some
                    # images have no caption file the two lists do not align.
                    "Caption1": caption_filenames[i],
                    "Caption2": caption_filenames[j],
                    "Similarity": caption_sim_matrix[i][j]
                })
        df_caption_similarity = pd.DataFrame(caption_similarities)
    else:
        df_caption_similarity = pd.DataFrame(columns=["Caption1", "Caption2", "Similarity"])
    return df_images, df_words, df_img_similarity, df_caption_similarity
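# e.g. (hypothetical path):
#   df_images, df_words, df_img_sim, df_cap_sim = analyze_directory("./dataset")
# where ./dataset holds images alongside their same-named .txt caption files.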
# Load existing presets
presets = load_presets(PRESETS_FILE)
output_path_presets = load_presets(OUTPUT_PATH_PRESETS_FILE)
# Gradio Interface
queue = []
def single_image_mode(api_key, model, image, prompt, output_path, max_image_dimension):
    captions, filenames = generate_image_captions(api_key, model, [image], prompt, max_image_dimension)
    save_caption_to_file(captions[0], filenames[0], output_path)
    return captions[0]

def batch_image_mode(api_key, model, images, prompt, output_path, max_image_dimension, folder_path=None):
    # Declare the global up front; the module-level queue is rebound below.
    global queue
    new_images = images if not folder_path else [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.split('.')[-1].lower() in SUPPORTED_FORMATS]
    queue = add_to_queue(new_images, queue)
    captions, failed_images, queue = process_queue(api_key, model, queue, prompt, output_path, max_image_dimension)
    return f"Captions saved automatically for all images. Failed images: {failed_images}", failed_images
def add_preset(preset_name, prompt):
    presets.append({"name": preset_name, "prompt": prompt})
    save_presets(presets, PRESETS_FILE)
    return gr.update(choices=[preset["name"] for preset in presets])

def delete_preset(preset_name):
    global presets
    presets = [preset for preset in presets if preset["name"] != preset_name]
    save_presets(presets, PRESETS_FILE)
    return gr.update(choices=[preset["name"] for preset in presets])

def load_preset(preset_name):
    for preset in presets:
        if preset["name"] == preset_name:
            return preset["prompt"]
    return ""

def add_output_path_preset(preset_name, output_path):
    output_path_presets.append({"name": preset_name, "output_path": output_path})
    save_presets(output_path_presets, OUTPUT_PATH_PRESETS_FILE)
    return gr.update(choices=[preset["name"] for preset in output_path_presets])

def delete_output_path_preset(preset_name):
    global output_path_presets
    output_path_presets = [preset for preset in output_path_presets if preset["name"] != preset_name]
    save_presets(output_path_presets, OUTPUT_PATH_PRESETS_FILE)
    return gr.update(choices=[preset["name"] for preset in output_path_presets])

def load_output_path_preset(preset_name):
    for preset in output_path_presets:
        if preset["name"] == preset_name:
            return preset["output_path"]
    return ""
def analyze_stats(directory_path):
    df_images, df_words, df_img_similarity, df_caption_similarity = analyze_directory(directory_path)
    # Image stats plot
    fig, ax = plt.subplots()
    df_images[['Width', 'Height']].plot(kind='hist', bins=30, alpha=0.5, ax=ax)
    ax.set_title("Image Dimension Distribution")
    ax.set_xlabel("Pixels")
    ax.set_ylabel("Frequency")
    plt.tight_layout()
    plt.savefig("image_stats.png")
    plt.close(fig)  # free the figure; repeated runs would otherwise accumulate
    # Word frequency plot
    fig, ax = plt.subplots()
    df_words.head(20).plot(kind='bar', x='Word', y='Frequency', ax=ax)
    ax.set_title("Top 20 Word Frequencies")
    ax.set_xlabel("Words")
    ax.set_ylabel("Frequency")
    plt.tight_layout()
    plt.savefig("word_stats.png")
    plt.close(fig)
    return df_images, df_words, df_img_similarity, df_caption_similarity
with gr.Blocks() as demo:
    gr.Markdown("## Image Caption Generator with GPT-4")
    api_key_input = gr.Textbox(label="API Key", value=load_api_key(), type="password")
    save_api_key_button = gr.Button("Save API Key", variant="primary")
    model_selection = gr.Dropdown(label="Model", choices=["gpt-4o", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"], value="gpt-4o")
    max_image_dimension_slider = gr.Slider(label="Max Image Dimension", minimum=128, maximum=1024, step=64, value=512)
    prompt_input = gr.Textbox(label="Prompt", lines=2)
    output_path_input = gr.Textbox(label="Output Path")
    with gr.Accordion("Presets", open=True):
        with gr.Column():
            with gr.Row():
                preset_name_input = gr.Textbox(label="Preset Name")
                output_path_preset_name_input = gr.Textbox(label="Output Path Preset Name")
            with gr.Row():
                presets_dropdown = gr.Dropdown(label="Presets", choices=[preset["name"] for preset in presets])
                output_path_presets_dropdown = gr.Dropdown(label="Output Path Presets", choices=[preset["name"] for preset in output_path_presets])
            with gr.Row():
                with gr.Column():
                    load_preset_button = gr.Button("Load Prompt Preset", variant="primary")
                    load_output_path_preset_button = gr.Button("Load Output Path Preset", variant="primary")
                with gr.Column():
                    delete_preset_button = gr.Button("Delete Prompt Preset", variant="primary")
                    delete_output_path_preset_button = gr.Button("Delete Output Path Preset", variant="primary")
                with gr.Column():
                    add_preset_button = gr.Button("Add Prompt Preset", variant="primary")
                    add_output_path_preset_button = gr.Button("Add Output Path Preset", variant="primary")
    add_preset_button.click(add_preset, [preset_name_input, prompt_input], presets_dropdown)
    delete_preset_button.click(delete_preset, presets_dropdown, presets_dropdown)
    load_preset_button.click(load_preset, presets_dropdown, prompt_input)
    add_output_path_preset_button.click(add_output_path_preset, [output_path_preset_name_input, output_path_input], output_path_presets_dropdown)
    delete_output_path_preset_button.click(delete_output_path_preset, output_path_presets_dropdown, output_path_presets_dropdown)
    load_output_path_preset_button.click(load_output_path_preset, output_path_presets_dropdown, output_path_input)
    with gr.Tabs():
        with gr.TabItem("Single Image Mode"):
            single_image_output = gr.Textbox(label="Generated Caption")
            with gr.Row():
                save_single_image_button = gr.Button("Save Caption", variant="primary")
                single_image_button = gr.Button("Generate Caption", variant="primary")
            image_input = gr.Image(type="filepath", label="Upload Image")
            single_image_button.click(single_image_mode, [api_key_input, model_selection, image_input, prompt_input, output_path_input, max_image_dimension_slider], single_image_output)
            # The image component yields a full path; pass only the basename so
            # the caption lands in the chosen output directory (os.path.join
            # discards output_path when given an absolute second argument).
            save_single_image_button.click(
                lambda caption, image_path, output_path: save_caption_to_file(caption, os.path.basename(image_path), output_path),
                [single_image_output, image_input, output_path_input], outputs=[])
        with gr.TabItem("Batch Image Mode"):
            folder_path_input = gr.Textbox(label="Folder Path (optional)")
            batch_image_output = gr.Textbox(label="Status")
            batch_image_button = gr.Button("Generate Captions", variant="primary")
            images_input = gr.Files(type="filepath", label="Upload Images", file_count="multiple")
            add_to_queue_button = gr.Button("Add to Queue", variant="primary")
            additional_images_input = gr.Files(type="filepath", label="Upload Additional Images", file_count="multiple")
            # Event outputs must be components, not raw Python lists, so the
            # queue is held in a gr.State. Note that the module-level `queue`
            # used by batch_image_mode is a separate, process-wide list.
            queue_state = gr.State(queue)
            add_to_queue_button.click(add_to_queue, [additional_images_input, queue_state], queue_state)
            batch_image_button.click(batch_image_mode, [api_key_input, model_selection, images_input, prompt_input, output_path_input, max_image_dimension_slider, folder_path_input], [batch_image_output, images_input])
        with gr.TabItem("Statistics"):
            directory_input = gr.Textbox(label="Directory Path")
            stats_button = gr.Button("Generate Stats", variant="primary")
            image_stats_output = gr.Dataframe(label="Image Stats")
            word_stats_output = gr.Dataframe(label="Word Stats")
            img_similarity_output = gr.Dataframe(label="Image Similarity")
            caption_similarity_output = gr.Dataframe(label="Caption Similarity")
            stats_button.click(analyze_stats, directory_input, [image_stats_output, word_stats_output, img_similarity_output, caption_similarity_output])
    save_api_key_button.click(save_api_key, inputs=[api_key_input], outputs=[])

demo.launch()