Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multi-caption parquets crashing in multiple locations (Closes #1092) #1109

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 81 additions & 18 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from tqdm import tqdm
import queue
from math import sqrt
import pandas as pd
import numpy as np

logger = logging.getLogger("DataBackendFactory")
if should_log():
Expand All @@ -48,6 +50,68 @@ def info_log(message):
logger.info(message)


def check_column_values(column_data, column_name, parquet_path, fallback_caption_column=False):
# Determine if the column contains arrays or scalar values
non_null_values = column_data.dropna()
if non_null_values.empty:
# All values are null
raise ValueError(
f"Parquet file {parquet_path} contains only null values in the '{column_name}' column."
)

first_non_null = non_null_values.iloc[0]
if isinstance(first_non_null, (list, tuple, np.ndarray, pd.Series)):
# Column contains arrays
# Check for null arrays
if column_data.isnull().any() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains null arrays in the '{column_name}' column."
)

# Check for empty arrays
empty_arrays = column_data.apply(lambda x: len(x) == 0)
if empty_arrays.any() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains empty arrays in the '{column_name}' column."
)

# Check for null elements within arrays
null_elements_in_arrays = column_data.apply(
lambda arr: any(pd.isnull(s) for s in arr)
)
if null_elements_in_arrays.any() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains null values within arrays in the '{column_name}' column."
)

# Check for empty strings within arrays
empty_strings_in_arrays = column_data.apply(
lambda arr: any(s == "" for s in arr)
)
if empty_strings_in_arrays.all() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains only empty strings within arrays in the '{column_name}' column."
)

elif isinstance(first_non_null, str):
# Column contains scalar strings
# Check for null values
if column_data.isnull().any() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains null values in the '{column_name}' column."
)

# Check for empty strings
if (column_data == "").any() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains empty strings in the '{column_name}' column."
)
else:
raise TypeError(
f"Unsupported data type in column '{column_name}'. Expected strings or arrays of strings."
)


def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output = {"id": backend["id"], "config": {}}
if backend.get("dataset_type", None) == "text_embeds":
Expand Down Expand Up @@ -292,24 +356,23 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken
raise ValueError(
f"Parquet file {parquet_path} does not contain a column named '{filename_column}'."
)
# Check for null values
if df[caption_column].isnull().values.any() and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains null values in the '{caption_column}' column, but no fallback_caption_column was set."
)
if df[filename_column].isnull().values.any():
raise ValueError(
f"Parquet file {parquet_path} contains null values in the '{filename_column}' column."
)
# Check for empty strings
if (df[caption_column] == "").sum() > 0 and not fallback_caption_column:
raise ValueError(
f"Parquet file {parquet_path} contains empty strings in the '{caption_column}' column."
)
if (df[filename_column] == "").sum() > 0:
raise ValueError(
f"Parquet file {parquet_path} contains empty strings in the '{filename_column}' column."
)

# Apply the function to the caption_column.
check_column_values(
df[caption_column],
caption_column,
parquet_path,
fallback_caption_column=fallback_caption_column
)

# Apply the function to the filename_column.
check_column_values(
df[filename_column],
filename_column,
parquet_path,
fallback_caption_column=False # Always check filename_column
)

# Store the database in StateTracker
StateTracker.set_parquet_database(
backend["id"],
Expand Down
12 changes: 7 additions & 5 deletions helpers/metadata/backends/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,21 @@ def _extract_captions_to_fast_list(self):
if len(caption_column) > 0:
caption = [row[c] for c in caption_column]
else:
caption = row[caption_column]
caption = row.get(caption_column)
if isinstance(caption, (numpy.ndarray, pd.Series)):
caption = [str(item) for item in caption if item is not None]

if not caption and fallback_caption_column:
caption = row[fallback_caption_column]
if not caption:
if caption is None and fallback_caption_column:
caption = row.get(fallback_caption_column, None)
if caption is None or caption == "" or caption == []:
raise ValueError(
f"Could not locate caption for image {filename} in sampler_backend {self.id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(self.parquet_database)} entries."
)
if type(caption) == bytes:
caption = caption.decode("utf-8")
elif type(caption) == list:
caption = [c.strip() for c in caption if c.strip()]
if caption:
elif type(caption) == str:
caption = caption.strip()
captions[filename] = caption
return captions
Expand Down
30 changes: 18 additions & 12 deletions helpers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training import image_file_extensions

import numpy

try:
import pandas as pd
except ImportError:
raise ImportError("Pandas is required for the ParquetMetadataBackend.")

prompts = {
"alien_landscape": "Alien planet, strange rock formations, glowing plants, bizarre creatures, surreal atmosphere",
"alien_market": "Alien marketplace, bizarre creatures, exotic goods, vibrant colors, otherworldly atmosphere",
Expand Down Expand Up @@ -256,8 +263,10 @@ def prepare_instance_prompt_from_parquet(
)
if type(image_caption) == bytes:
image_caption = image_caption.decode("utf-8")
if image_caption:
if type(image_caption) == str:
image_caption = image_caption.strip()
if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series):
image_caption = [str(item).strip() for item in image_caption if item is not None]
if prepend_instance_prompt:
if type(image_caption) == list:
image_caption = [instance_prompt + " " + x for x in image_caption]
Expand Down Expand Up @@ -436,17 +445,14 @@ def get_all_captions(
data_backend=data_backend,
)
elif caption_strategy == "parquet":
try:
caption = PromptHandler.prepare_instance_prompt_from_parquet(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
sampler_backend_id=data_backend.id,
)
except:
continue
caption = PromptHandler.prepare_instance_prompt_from_parquet(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
sampler_backend_id=data_backend.id,
)
elif caption_strategy == "instanceprompt":
return [instance_prompt]
elif caption_strategy == "csv":
Expand Down
59 changes: 59 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import unittest
import pandas as pd
from unittest.mock import patch, Mock, MagicMock
from PIL import Image
from pathlib import Path
from helpers.multiaspect.dataset import MultiAspectDataset
from helpers.metadata.backends.discovery import DiscoveryMetadataBackend
from helpers.data_backend.base import BaseDataBackend
from helpers.data_backend.factory import check_column_values


class TestMultiAspectDataset(unittest.TestCase):
Expand Down Expand Up @@ -82,5 +84,62 @@ def test_getitem_invalid_image(self):
self.dataset.__getitem__(self.image_metadata)


class TestDataBackendFactory(unittest.TestCase):
def test_all_null(self):
column_data = pd.Series([None, None, None])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains only null values", str(context.exception))

def test_arrays_with_nulls(self):
column_data = pd.Series([[1, 2], None, [3, 4]])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains null arrays", str(context.exception))

def test_empty_arrays(self):
column_data = pd.Series([[1, 2], [], [3, 4]])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains empty arrays", str(context.exception))

def test_null_elements_in_arrays(self):
column_data = pd.Series([[1, None], [2, 3], [3, 4]])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains null values within arrays", str(context.exception))

def test_empty_strings_in_arrays(self):
column_data = pd.Series([["", ""], ["", ""], ["", ""]])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains only empty strings within arrays", str(context.exception))

def test_scalar_strings_with_nulls(self):
column_data = pd.Series(["a", None, "b"])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains null values", str(context.exception))

def test_scalar_strings_with_empty(self):
column_data = pd.Series(["a", "", "b"])
with self.assertRaises(ValueError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("contains empty strings", str(context.exception))

def test_with_fallback_caption(self):
column_data = pd.Series([None, "", [None], [""]])
try:
check_column_values(column_data, "test_column", "test_file.parquet", fallback_caption_column=True)
except ValueError:
self.fail("check_column_values() raised ValueError unexpectedly with fallback_caption_column=True")

def test_invalid_data_type(self):
column_data = pd.Series([1, 2, 3])
with self.assertRaises(TypeError) as context:
check_column_values(column_data, "test_column", "test_file.parquet")
self.assertIn("Unsupported data type in column", str(context.exception))


if __name__ == "__main__":
unittest.main()
Loading