From ec00d09708ebd80b88709fb136ed0969d544e2fc Mon Sep 17 00:00:00 2001 From: peichins Date: Mon, 22 Apr 2024 16:26:26 +1000 Subject: [PATCH] updated batch entry point --- src/batch.py | 39 ++++++++++++++++++++++++++++----------- src/inference_parquet.py | 22 +++++++++++++++++++--- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/src/batch.py b/src/batch.py index 339a29e..ade761e 100644 --- a/src/batch.py +++ b/src/batch.py @@ -8,9 +8,13 @@ import yaml import csv from pathlib import Path +import time +import sys +import tqdm from src.embed_audio_slim import embed_file_and_save from src.config import load_config +from src.inference_parquet import classify_file_and_save #import train_linear_model #import inference_slim @@ -57,31 +61,44 @@ def batch(command, source_csv, start_row, end_row, config_file, overwrite_existi else: print(f"processing {item['source']}") - match command: - case "generate": - embed_file_and_save(item['source'], item['output'], config) - case "inference": - # inference_slim.analze - print("inference") - case _: - print("invalid command") + sys.stdout.flush() + + try: + match command: + case "generate": + embed_file_and_save(item['source'], item['output'], config) + case "classify": + # inference_slim.analze + classify_file_and_save(item['source'], item['output'], config) + case _: + print("invalid command") + except BlockingIOError as e: + print(f"IO error with file {item['source']}: {e}, retrying in 1 second...") + time.sleep(0.1) # Wait for before moving to the next file + continue # Skip to the next iteration + except Exception as e: + print(f"An error occurred with file {item['source']}: {e}") + continue + + def main (): """Just the arg parsing from command line""" - valid_commands = ('generate', 'inference') + valid_commands = ('generate', 'classify') parser = argparse.ArgumentParser() parser.add_argument("command", choices=list(valid_commands), help=" | ".join(valid_commands)) parser.add_argument("--source_csv", help="path to a csv that has the columns 'source' and 'output'") parser.add_argument("--start_row", default=None, help="which row on the csv to start from (zero index)") parser.add_argument("--end_row", help="last row in the csv to process (zero index)") - parser.add_argument("--config_file", default=None, help="path to the config file") + parser.add_argument("--config", default=None, help="path to the config file") parser.add_argument("--overwrite_existing", default=False, help="if true, will overwrite existing files, else will skip if exists") args = parser.parse_args() - batch(args.command, args.source_csv, int(args.start_row), int(args.end_row), args.config_file, args.overwrite_existing) + batch(args.command, args.source_csv, int(args.start_row), int(args.end_row), args.config, args.overwrite_existing) if __name__ == "__main__": + main() diff --git a/src/inference_parquet.py b/src/inference_parquet.py index f487807..8a64c56 100644 --- a/src/inference_parquet.py +++ b/src/inference_parquet.py @@ -7,7 +7,8 @@ from functools import partial import random from dataclasses import dataclass - +import sys +import time import keras from ml_collections import ConfigDict @@ -105,11 +106,15 @@ def process_folder(input_path, output_path, config): """ # slightly more efficient to load the classifier once rather than for each file + print(f"loading classifier {config.classifier}") classifier = load_classifier(config.classifier) # list of paths to the embeddings files relative to the embeddings_path + print("getting list of embeddings files") + sys.stdout.flush() embeddings_files_relative = [path.relative_to(input_path) for path in Path(input_path).rglob('*.parquet')] + print(f'found {len(embeddings_files_relative)} embeddings files') # dodgy parallel: shuffle and start script in a different process with skip_if_file_exists=True random.shuffle(embeddings_files_relative) @@ -119,12 +124,23 @@ def process_folder(input_path, output_path, config): for index, embedding_file in enumerate(tqdm.tqdm(embeddings_files_relative, desc="Processing")): file_output_path = output_path / embedding_file.with_suffix('.csv') + # print(f'skipping {embedding_file} to {file_output_path} because we are just testing') + # continue if config.skip_if_file_exists and file_output_path.exists(): print(f'skipping {embedding_file} as {file_output_path} already exists') continue #print(f'processing {index} of {len(embeddings_files_relative)}: {embedding_file}') - results = classify_embeddings_file(input_path / embedding_file, classifier) - save_classification_results(results, file_output_path) + + try: + results = classify_embeddings_file(input_path / embedding_file, classifier) + save_classification_results(results, file_output_path) + except BlockingIOError as e: + print(f"IO error with file {embedding_file}: {e}, retrying in 1 second...") + time.sleep(0.1) # Wait for before moving to the next file + continue # Skip to the next iteration + except Exception as e: + print(f"An error occurred with file {embedding_file}: {e}") + continue # Optionally handle other exceptions and move on print(f'finished processing {len(embeddings_files_relative)} embeddings files')