disable caption deduplication as it prevents multigpu caching; add warning for sd3 using wrong VAE; cleanly terminate and restart batch text embed writing thread #1111

Merged (1 commit) on Oct 30, 2024

25 changes: 22 additions & 3 deletions helpers/caching/text_embeds.py
@@ -208,14 +208,21 @@ def discover_all_files(self):

def save_to_cache(self, filename, embeddings):
"""Add write requests to the queue instead of writing directly."""
self.process_write_batches = True
if not self.batch_write_thread.is_alive():
logger.debug("Restarting background write thread.")
# Start the thread again.
self.process_write_batches = True
self.batch_write_thread = Thread(target=self.batch_write_embeddings)
self.batch_write_thread.start()
self.write_queue.put((embeddings, filename))
logger.debug(
f"save_to_cache called for {filename}, write queue has {self.write_queue.qsize()} items, and the write thread's status: {self.batch_write_thread.is_alive()}"
)

def batch_write_embeddings(self):
"""Process write requests in batches."""
batch = []
written_elements = 0
while True:
try:
# Block until an item is available or timeout occurs
@@ -226,14 +233,25 @@ def batch_write_embeddings(self):
while (
not self.write_queue.empty() and len(batch) < self.write_batch_size
):
logger.debug("Retrieving more items from the queue.")
items = self.write_queue.get_nowait()
batch.append(items)
logger.debug(f"Batch now contains {len(batch)} items.")

self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug("Processed batch write.")
written_elements += len(batch)

except queue.Empty:
# Timeout occurred, no items were ready
if not self.process_write_batches:
if len(batch) > 0:
self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug(f"Exiting batch write thread, no more work to do after writing {written_elements} elements")
break
logger.debug(f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}")
pass
except Exception:
logger.exception("An error occurred while writing embeddings to disk.")
@@ -242,6 +260,7 @@
def process_write_batch(self, batch):
"""Write a batch of embeddings to the cache."""
logger.debug(f"Writing {len(batch)} items to disk")
logger.debug(f"Batch: {batch}")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(self.data_backend.torch_save, *args) for args in batch
@@ -1301,7 +1320,7 @@ def compute_embeddings_for_sd3_prompts(
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
self.debug_log(f"Encoding prompt: {prompt}")
self.debug_log(f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}")
prompt_embeds, pooled_prompt_embeds = self.encode_sd3_prompt(
self.text_encoders,
self.tokenizers,
@@ -1314,7 +1333,7 @@
),
)
logger.debug(
f"SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
f"Filename {filename} SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
)
add_text_embeds = pooled_prompt_embeds
# StabilityAI say not to zero them out.
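
Taken together, the save_to_cache and batch_write_embeddings changes above form a restartable producer/consumer writer: the background thread drains the queue in batches, flushes any residual batch and exits once process_write_batches is cleared, and save_to_cache revives the thread if it has stopped. Below is a minimal standalone sketch of that pattern; the class and default values are simplified and hypothetical, not the project's API.

import queue
from threading import Thread

class BatchWriter:
    """Sketch of a restartable batched writer (hypothetical, simplified names)."""

    def __init__(self, write_batch_size=64):
        self.write_queue = queue.Queue()
        self.write_batch_size = write_batch_size
        self.process_write_batches = True
        self.batch_write_thread = Thread(target=self.batch_write_embeddings)
        self.batch_write_thread.start()

    def save_to_cache(self, filename, embeddings):
        # Revive the writer if a previous clean shutdown stopped it.
        if not self.batch_write_thread.is_alive():
            self.process_write_batches = True
            self.batch_write_thread = Thread(target=self.batch_write_embeddings)
            self.batch_write_thread.start()
        self.write_queue.put((embeddings, filename))

    def batch_write_embeddings(self):
        batch = []
        while True:
            try:
                # Block until an item arrives, then opportunistically fill the batch.
                batch.append(self.write_queue.get(timeout=1))
                while not self.write_queue.empty() and len(batch) < self.write_batch_size:
                    batch.append(self.write_queue.get_nowait())
                self.process_write_batch(batch)
                batch = []
            except queue.Empty:
                if not self.process_write_batches:
                    if batch:
                        self.process_write_batch(batch)  # flush any residue
                    break  # clean exit; save_to_cache() restarts the thread on demand

    def process_write_batch(self, batch):
        for embeddings, filename in batch:
            pass  # e.g. hand off to a storage backend's torch-save routine

    def stop(self):
        # Clearing the flag lets the thread drain, flush, and exit after its next timeout.
        self.process_write_batches = False
        self.batch_write_thread.join()

In the real module the batch is written through a ThreadPoolExecutor and data_backend.torch_save, as shown in process_write_batch above; the sketch only illustrates the shutdown-and-restart lifecycle.
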
2 changes: 1 addition & 1 deletion helpers/configuration/cmd_args.py
@@ -2079,7 +2079,7 @@ def parse_cmdline_args(input_args=None):

if (
args.pretrained_vae_model_name_or_path is not None
and args.model_family in ["legacy", "flux"]
and args.model_family in ["legacy", "flux", "sd3"]
and "sdxl" in args.pretrained_vae_model_name_or_path
and "deepfloyd" not in args.model_type
):
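
The one-line change above extends the existing guard so that SD3 configurations pointing at an SDXL VAE also trip the warning, alongside the legacy and flux families. A self-contained restatement of the check, with an illustrative message rather than the project's actual wording:

import logging

logger = logging.getLogger(__name__)

def warn_on_sdxl_vae(args):
    # Mirrors the condition in parse_cmdline_args: legacy, flux and now sd3
    # model families should not be paired with an SDXL VAE checkpoint.
    if (
        args.pretrained_vae_model_name_or_path is not None
        and args.model_family in ["legacy", "flux", "sd3"]
        and "sdxl" in args.pretrained_vae_model_name_or_path
        and "deepfloyd" not in args.model_type
    ):
        logger.warning(
            f"{args.pretrained_vae_model_name_or_path} looks like an SDXL VAE, "
            f"which is not compatible with the {args.model_family} model family."
        )
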
2 changes: 1 addition & 1 deletion helpers/data_backend/factory.py
@@ -430,7 +430,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
accelerator=accelerator,
cache_dir=init_backend.get("cache_dir", args.cache_dir_text),
model_type=StateTracker.get_model_family(),
write_batch_size=backend.get("write_batch_size", 1),
write_batch_size=backend.get("write_batch_size", args.write_batch_size),
)
init_backend["text_embed_cache"].set_webhook_handler(
StateTracker.get_webhook_handler()
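
Previously the text embed cache was always constructed with write_batch_size=1 unless the backend config set it explicitly; the change above makes the per-backend key fall back to the CLI argument instead of a hard-coded 1. A small sketch of that fallback chain, with a hypothetical dataclass and default value:

from dataclasses import dataclass

@dataclass
class Args:
    write_batch_size: int = 128  # hypothetical CLI default

def resolve_write_batch_size(backend: dict, args: Args) -> int:
    # A per-backend setting wins; otherwise inherit the CLI-level default
    # rather than a hard-coded 1.
    return backend.get("write_batch_size", args.write_batch_size)

assert resolve_write_batch_size({}, Args()) == 128
assert resolve_write_batch_size({"write_batch_size": 8}, Args()) == 8
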
3 changes: 2 additions & 1 deletion helpers/prompts.py
@@ -463,7 +463,8 @@ def get_all_captions(
captions.extend(caption)

# Deduplicate captions
captions = list(set(captions))
# TODO: Investigate why this prevents captions from processing on multigpu systems.
# captions = list(set(captions))

return captions

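
The deduplication above is commented out rather than fixed, and the TODO leaves the root cause open. One plausible explanation, an assumption not confirmed by this PR, is that set iteration order depends on Python's per-process string hash seed, so each GPU rank can see the deduplicated captions in a different order and disagree about which embeds are already cached. If deduplication is ever reinstated, an order-preserving variant avoids that particular failure mode:

def dedupe_preserving_order(captions):
    # dict preserves insertion order (Python 3.7+), so duplicates are dropped
    # without the cross-process reshuffling that list(set(...)) can introduce.
    return list(dict.fromkeys(captions))

assert dedupe_preserving_order(["a cat", "a dog", "a cat"]) == ["a cat", "a dog"]
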