From 8eedd1f31e5ef000f17817873821e7484af0c53b Mon Sep 17 00:00:00 2001 From: jabberjabberjabber <75267439+jabberjabberjabber@users.noreply.github.com> Date: Sun, 15 Dec 2024 19:10:50 -0500 Subject: [PATCH] Add files via upload Fix config bug. --- chunkify-gui.py | 957 +++++++++++++++++++++++------------------------ chunkify.py | 972 ++++++++++++++++++++++++------------------------ 2 files changed, 973 insertions(+), 956 deletions(-) diff --git a/chunkify-gui.py b/chunkify-gui.py index 58e0812..3bdcaa0 100644 --- a/chunkify-gui.py +++ b/chunkify-gui.py @@ -1,478 +1,479 @@ -import sys -import os -import json -import time - -from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, - QHBoxLayout, QPushButton, QRadioButton, QButtonGroup, - QFileDialog, QListWidget, QLabel, QTextEdit, QLineEdit, - QGroupBox, QPlainTextEdit, QDialog, QComboBox, QSpinBox, - QDoubleSpinBox, QMenuBar, QMenu, QAction, QSizePolicy) -from PyQt5.QtCore import Qt, QThread, pyqtSignal -from pathlib import Path -from chunkify import LLMConfig, LLMProcessor, check_api - -## -## GUI written mostly by Claude Sonnet 3.5 -## - -class ProcessingThread(QThread): - progress_signal = pyqtSignal(str) - finished_signal = pyqtSignal(list) - - def __init__(self, config, task, instruction, files, selected_template=None): - super().__init__() - self.config = config - self.task = task - self.instruction = instruction - self.files = files - self.selected_template = selected_template - def run(self): - try: - self.processor = LLMProcessor(self.config, self.task) - - # Override the monitor function for the GUI - def gui_monitor(): - generating = False - last_result = "" - payload = {'genkey': self.processor.genkey} - while not self.processor.generated: - result = self.processor._call_api("check", payload) - if not result: - time.sleep(2) - continue - if result != last_result: # Only update if the text has changed - last_result = result - # Send a clear signal before new text - 
self.progress_signal.emit("<>") - self.progress_signal.emit(f"{result}") - time.sleep(1) - - # Replace the monitor function - self.processor._monitor_generation = gui_monitor - - results = [] - - for file_path in self.files: - self.progress_signal.emit(f"Processing {file_path}...") - content, metadata = self.processor._get_content(file_path) - - if self.task == "custom": - responses = self.processor.process_in_chunks(self.instruction, content) - else: - responses = self.processor.route_task(self.task, content) - - # Create output filename - path = Path(file_path) - output_path = path.parent / f"{path.stem}_processed{path.suffix}" - - # Write output - with open(output_path, 'w', encoding='utf-8') as f: - f.write(f"File: {metadata.get('resourceName', 'Unknown')}\n") - f.write(f"Type: {metadata.get('Content-Type', 'Unknown')}\n") - f.write(f"Encoding: {metadata.get('Content-Encoding', 'Unknown')}\n") - f.write(f"Length: {metadata.get('Content-Length', 'Unknown')}\n\n") - for response in responses: - f.write(f"{response}\n\n") - - results.append((file_path, output_path)) - - self.finished_signal.emit(results) - - except Exception as e: - self.progress_signal.emit(f"Error: {str(e)}") - -class ChunkerGUI(QMainWindow): - def __init__(self): - super().__init__() - self.config_file = './chunkify_config.json' - self.config = self.load_config() - self.selected_template = None - self.initUI() - self.api_ready = False - - # Start API check timer - from PyQt5.QtCore import QTimer - self.api_timer = QTimer() - self.api_timer.timeout.connect(self.check_api) - self.api_timer.start(2000) # Check every 2 seconds - #self.check_api() - def check_api(self): - try: - #while not self.api_ready: - result = check_api(self.config.api_url) - if result: - self.api_ready = True - self.process_button.setEnabled(True) - self.output_text.appendPlainText("API is ready - you can now process files.") - self.api_timer.stop() - else: - if not self.api_ready: # Only show loading message if not yet ready 
- self.process_button.setEnabled(False) - self.output_text.setPlainText("Waiting for API to become available...\nPlease wait for model weights to download.") - #time.sleep(2) - except Exception as e: - if not self.api_ready: - self.process_button.setEnabled(False) - self.output_text.setPlainText(f"Waiting for API...\n") - - def initUI(self): - self.setWindowTitle('Text Processing GUI') - self.setGeometry(100, 100, 800, 600) - - # Main widget and layout - main_widget = QWidget() - self.setCentralWidget(main_widget) - layout = QVBoxLayout() - - # File selection area - file_group = QGroupBox("File Selection") - file_layout = QVBoxLayout() - - self.file_list = QListWidget() - file_buttons = QHBoxLayout() - - add_button = QPushButton('Add Files') - add_button.clicked.connect(self.add_files) - remove_button = QPushButton('Remove Selected') - remove_button.clicked.connect(self.remove_files) - - file_buttons.addWidget(add_button) - file_buttons.addWidget(remove_button) - - file_layout.addLayout(file_buttons) - file_layout.addWidget(self.file_list) - file_group.setLayout(file_layout) - self.file_list.setFixedHeight(100) # Adjust this value as needed - self.file_list.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) - self.file_list.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) - - # For the output text, we can use size policies to make it expand - - # Task selection area - task_group = QGroupBox("Task Selection") - task_layout = QVBoxLayout() - - self.task_group = QButtonGroup() - tasks = ['summary', 'translate', 'distill', 'correct'] - - for i, task in enumerate(tasks): - radio = QRadioButton(task.capitalize()) - self.task_group.addButton(radio, i) - task_layout.addWidget(radio) - if task == 'summary': - radio.setChecked(True) - - - task_group.setLayout(task_layout) - - # Output area - output_group = QGroupBox("Output") - output_layout = QVBoxLayout() - - self.output_text = QPlainTextEdit() - self.output_text.setStyleSheet(""" - QPlainTextEdit { - background-color: 
black; - color: white; - font-family: Consolas, Monaco, monospace; - } - """) - self.output_text.setReadOnly(True) - self.output_text.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - output_layout.addWidget(self.output_text) - output_group.setLayout(output_layout) - - # Process button - self.process_button = QPushButton('Process Files') - self.process_button.clicked.connect(self.process_files) - - # Add everything to main layout - layout.addWidget(file_group) - layout.addWidget(task_group) - layout.addWidget(output_group) - layout.addWidget(self.process_button) - - main_widget.setLayout(layout) - file_group.setMaximumHeight(200) # Adjust this value as needed - - # Keep task group compact - task_group.setMaximumHeight(150) # Adjust this value as needed - - # Make output group expand - output_group.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - - # Keep process button from expanding - self.process_button.setFixedHeight(30) # Optional, for consistent button height - # Menu bar - menubar = self.menuBar() - settings_menu = menubar.addMenu('Settings') - - config_action = QAction('Configuration', self) - config_action.triggered.connect(self.show_config_dialog) - settings_menu.addAction(config_action) - - # Initialize processing thread as None - self.processing_thread = None - - def add_files(self): - files, _ = QFileDialog.getOpenFileNames( - self, - "Select files to process", - "", - "All Files (*.*)" - ) - for file in files: - self.file_list.addItem(file) - - def remove_files(self): - for item in self.file_list.selectedItems(): - self.file_list.takeItem(self.file_list.row(item)) - - def process_files(self): - if self.processing_thread and self.processing_thread.isRunning(): - return - - files = [self.file_list.item(i).text() - for i in range(self.file_list.count())] - - if not files: - self.output_text.appendPlainText("Error: No files selected") - return - - # Clear the output box before starting - self.output_text.clear() - - # Get selected 
task - task_id = self.task_group.checkedId() - tasks = ['summary', 'translate', 'distill', 'correct'] - task = tasks[task_id] - instruction = "" - - config = self.config - - self.processing_thread = ProcessingThread( - config=config, - task=task, - instruction=instruction, - files=files, - selected_template=self.selected_template - ) - - # Update progress handler to check for clear signal - def handle_progress(msg): - if msg == "<>": - self.output_text.clear() - else: - self.output_text.appendPlainText(msg) - - self.processing_thread.progress_signal.connect(handle_progress) - self.processing_thread.finished_signal.connect(self.processing_finished) - self.processing_thread.start() - - self.process_button.setEnabled(False) - - - def processing_finished(self, results): - self.output_text.appendPlainText("\nProcessing completed!") - for input_file, output_file in results: - self.output_text.appendPlainText( - f"\nProcessed {input_file}\nOutput saved to {output_file}" - ) - self.process_button.setEnabled(True) - - def show_config_dialog(self): - dialog = ConfigDialog(self.config, self) - if dialog.exec_() == QDialog.Accepted: - # Existing config updates... 
- self.config.api_url = dialog.api_url_input.text() - self.config.api_password = dialog.api_password_input.text() - self.config.temp = dialog.temp_input.value() - self.config.rep_pen = dialog.rep_pen_input.value() - self.config.top_k = dialog.top_k_input.value() - self.config.top_p = dialog.top_p_input.value() - self.config.min_p = dialog.min_p_input.value() - self.selected_template = dialog.template_combo.currentText() - - # Update translation language in config - self.config.translation_language = dialog.translation_language_input.text() - - # Save config to file - self.save_config() - - # Immediately check API with new settings - self.api_ready = False - self.check_api() - - # If we have an active processing thread, update its config and refresh instructions - if self.processing_thread and self.processing_thread.isRunning(): - self.processing_thread.config = self.config - if hasattr(self.processing_thread, 'processor'): - self.processing_thread.processor.update_config(self.config) - - def load_config(self): - """Load configuration from JSON file.""" - try: - if os.path.exists(self.config_file): - with open(self.config_file) as f: - config_data = json.load(f) - - - # Load template if it exists - self.selected_template = config_data.pop('selected_template', None) - - return LLMConfig(**config_data) - else: - # Return default config - return LLMConfig( - templates_directory="./templates", - api_url="http://localhost:5001", - api_password="" - ) - except Exception as e: - print(f"Error loading config: {e}") - return LLMConfig( - templates_directory="./templates", - api_url="http://localhost:5001", - api_password="" - ) - - def save_config(self): - """Save current configuration to JSON file.""" - try: - config_data = { - 'templates_directory': self.config.templates_directory, - 'api_url': self.config.api_url, - 'api_password': self.config.api_password, - 'temp': self.config.temp, - 'rep_pen': self.config.rep_pen, - 'top_k': self.config.top_k, - 'top_p': 
self.config.top_p, - 'min_p': self.config.min_p, - 'selected_template': self.selected_template, - 'translation_language': self.translation_language - } - - with open(self.config_file, 'w') as f: - json.dump(config_data, f, indent=4) - - except Exception as e: - print(f"Error saving config: {e}") - - -class ConfigDialog(QDialog): - def __init__(self, config, parent=None): - super().__init__(parent) - self.config = config - - self.initUI() - - def initUI(self): - self.setWindowTitle('Configuration') - self.setModal(True) - layout = QVBoxLayout() - - # API Settings - api_group = QGroupBox("API Settings") - api_layout = QVBoxLayout() - - self.api_url_input = QLineEdit(self.config.api_url) - api_layout.addWidget(QLabel("API URL:")) - api_layout.addWidget(self.api_url_input) - - self.api_password_input = QLineEdit(self.config.api_password) - self.api_password_input.setEchoMode(QLineEdit.Password) - api_layout.addWidget(QLabel("API Password:")) - api_layout.addWidget(self.api_password_input) - - self.translation_language_input = QLineEdit(self.config.translation_language) - #self.translation_language_input.setEchoMode(QLineEdit.translation_language) - api_layout.addWidget(QLabel("Translate to language:")) - api_layout.addWidget(self.translation_language_input) - - api_group.setLayout(api_layout) - - # Sampler Settings - sampler_group = QGroupBox("Sampler Settings") - sampler_layout = QVBoxLayout() - - self.temp_input = QDoubleSpinBox() - self.temp_input.setRange(0, 2) - self.temp_input.setSingleStep(0.1) - self.temp_input.setValue(self.config.temp) - sampler_layout.addWidget(QLabel("Temperature:")) - sampler_layout.addWidget(self.temp_input) - - self.rep_pen_input = QDoubleSpinBox() - self.rep_pen_input.setRange(0, 10) - self.rep_pen_input.setSingleStep(0.1) - self.rep_pen_input.setValue(self.config.rep_pen) - sampler_layout.addWidget(QLabel("Repetition Penalty:")) - sampler_layout.addWidget(self.rep_pen_input) - - self.top_k_input = QSpinBox() - 
self.top_k_input.setRange(0, 100) - self.top_k_input.setValue(self.config.top_k) - sampler_layout.addWidget(QLabel("Top K:")) - sampler_layout.addWidget(self.top_k_input) - - self.top_p_input = QDoubleSpinBox() - self.top_p_input.setRange(0, 1) - self.top_p_input.setSingleStep(0.1) - self.top_p_input.setValue(self.config.top_p) - sampler_layout.addWidget(QLabel("Top P:")) - sampler_layout.addWidget(self.top_p_input) - - self.min_p_input = QDoubleSpinBox() - self.min_p_input.setRange(0, 1) - self.min_p_input.setSingleStep(0.01) - self.min_p_input.setValue(self.config.min_p) - sampler_layout.addWidget(QLabel("Min P:")) - sampler_layout.addWidget(self.min_p_input) - - sampler_group.setLayout(sampler_layout) - - - # Template Selection - template_group = QGroupBox("Template Selection") - template_layout = QVBoxLayout() - - self.template_combo = QComboBox() - # Load available templates - template_path = Path(self.config.templates_directory) - if template_path.exists(): - templates = [f.stem for f in template_path.glob('*.json')] - self.template_combo.addItems(['Auto'] + templates) - else: - self.template_combo.addItem('Auto') - - template_layout.addWidget(QLabel("Model Template:")) - template_layout.addWidget(self.template_combo) - - template_group.setLayout(template_layout) - - # Buttons - button_layout = QHBoxLayout() - save_button = QPushButton('Save') - save_button.clicked.connect(self.accept) - cancel_button = QPushButton('Cancel') - cancel_button.clicked.connect(self.reject) - button_layout.addWidget(save_button) - button_layout.addWidget(cancel_button) - - # Add all groups to main layout - layout.addWidget(api_group) - layout.addWidget(sampler_group) - layout.addWidget(template_group) - layout.addLayout(button_layout) - - self.setLayout(layout) - -def main(): - app = QApplication(sys.argv) - gui = ChunkerGUI() - gui.show() - sys.exit(app.exec_()) - -if __name__ == '__main__': - main() +import sys +import os +import json +import time + +from PyQt5.QtWidgets import 
(QApplication, QMainWindow, QWidget, QVBoxLayout, + QHBoxLayout, QPushButton, QRadioButton, QButtonGroup, + QFileDialog, QListWidget, QLabel, QTextEdit, QLineEdit, + QGroupBox, QPlainTextEdit, QDialog, QComboBox, QSpinBox, + QDoubleSpinBox, QMenuBar, QMenu, QAction, QSizePolicy) +from PyQt5.QtCore import Qt, QThread, pyqtSignal +from pathlib import Path +from chunkify import LLMConfig, LLMProcessor, check_api + +## +## GUI written mostly by Claude Sonnet 3.5 +## + +class ProcessingThread(QThread): + progress_signal = pyqtSignal(str) + finished_signal = pyqtSignal(list) + + def __init__(self, config, task, instruction, files, selected_template=None): + super().__init__() + self.config = config + self.task = task + self.instruction = instruction + self.files = files + self.selected_template = selected_template + def run(self): + try: + self.processor = LLMProcessor(self.config, self.task) + + # Override the monitor function for the GUI + def gui_monitor(): + generating = False + last_result = "" + payload = {'genkey': self.processor.genkey} + while not self.processor.generated: + result = self.processor._call_api("check", payload) + if not result: + time.sleep(2) + continue + if result != last_result: # Only update if the text has changed + last_result = result + # Send a clear signal before new text + self.progress_signal.emit("<>") + self.progress_signal.emit(f"{result}") + time.sleep(1) + + # Replace the monitor function + self.processor._monitor_generation = gui_monitor + + results = [] + + for file_path in self.files: + self.progress_signal.emit(f"Processing {file_path}...") + content, metadata = self.processor._get_content(file_path) + + if self.task == "custom": + responses = self.processor.process_in_chunks(self.instruction, content) + else: + responses = self.processor.route_task(self.task, content) + + # Create output filename + path = Path(file_path) + output_path = path.parent / f"{path.stem}_processed{path.suffix}" + + # Write output + with 
open(output_path, 'w', encoding='utf-8') as f: + f.write(f"File: {metadata.get('resourceName', 'Unknown')}\n") + f.write(f"Type: {metadata.get('Content-Type', 'Unknown')}\n") + f.write(f"Encoding: {metadata.get('Content-Encoding', 'Unknown')}\n") + f.write(f"Length: {metadata.get('Content-Length', 'Unknown')}\n\n") + for response in responses: + f.write(f"{response}\n\n") + + results.append((file_path, output_path)) + + self.finished_signal.emit(results) + + except Exception as e: + self.progress_signal.emit(f"Error: {str(e)}") + +class ChunkerGUI(QMainWindow): + def __init__(self): + super().__init__() + self.config_file = './chunkify_config.json' + self.config = self.load_config() + self.selected_template = None + self.initUI() + self.api_ready = False + + # Start API check timer + from PyQt5.QtCore import QTimer + self.api_timer = QTimer() + self.api_timer.timeout.connect(self.check_api) + self.api_timer.start(2000) # Check every 2 seconds + + #self.check_api() + def check_api(self): + try: + #while not self.api_ready: + result = check_api(self.config.api_url) + if result: + self.api_ready = True + self.process_button.setEnabled(True) + self.output_text.appendPlainText("API is ready - you can now process files.") + self.api_timer.stop() + else: + if not self.api_ready: # Only show loading message if not yet ready + self.process_button.setEnabled(False) + self.output_text.setPlainText("Waiting for API to become available...") + #time.sleep(2) + except Exception as e: + if not self.api_ready: + self.process_button.setEnabled(False) + self.output_text.setPlainText(f"Waiting for API...\n") + + def initUI(self): + self.setWindowTitle('Text Processing GUI') + self.setGeometry(100, 100, 800, 600) + + # Main widget and layout + main_widget = QWidget() + self.setCentralWidget(main_widget) + layout = QVBoxLayout() + + # File selection area + file_group = QGroupBox("File Selection") + file_layout = QVBoxLayout() + + self.file_list = QListWidget() + file_buttons = 
QHBoxLayout() + + add_button = QPushButton('Add Files') + add_button.clicked.connect(self.add_files) + remove_button = QPushButton('Remove Selected') + remove_button.clicked.connect(self.remove_files) + + file_buttons.addWidget(add_button) + file_buttons.addWidget(remove_button) + + file_layout.addLayout(file_buttons) + file_layout.addWidget(self.file_list) + file_group.setLayout(file_layout) + self.file_list.setFixedHeight(100) # Adjust this value as needed + self.file_list.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) + self.file_list.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) + + # For the output text, we can use size policies to make it expand + + # Task selection area + task_group = QGroupBox("Task Selection") + task_layout = QVBoxLayout() + + self.task_group = QButtonGroup() + tasks = ['summary', 'translate', 'distill', 'correct'] + + for i, task in enumerate(tasks): + radio = QRadioButton(task.capitalize()) + self.task_group.addButton(radio, i) + task_layout.addWidget(radio) + if task == 'summary': + radio.setChecked(True) + + + task_group.setLayout(task_layout) + + # Output area + output_group = QGroupBox("Output") + output_layout = QVBoxLayout() + + self.output_text = QPlainTextEdit() + self.output_text.setStyleSheet(""" + QPlainTextEdit { + background-color: black; + color: white; + font-family: Consolas, Monaco, monospace; + } + """) + self.output_text.setReadOnly(True) + self.output_text.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + output_layout.addWidget(self.output_text) + output_group.setLayout(output_layout) + + # Process button + self.process_button = QPushButton('Process Files') + self.process_button.clicked.connect(self.process_files) + + # Add everything to main layout + layout.addWidget(file_group) + layout.addWidget(task_group) + layout.addWidget(output_group) + layout.addWidget(self.process_button) + + main_widget.setLayout(layout) + file_group.setMaximumHeight(200) # Adjust this value as needed + + # Keep task 
group compact + task_group.setMaximumHeight(150) # Adjust this value as needed + + # Make output group expand + output_group.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + + # Keep process button from expanding + self.process_button.setFixedHeight(30) # Optional, for consistent button height + # Menu bar + menubar = self.menuBar() + settings_menu = menubar.addMenu('Settings') + + config_action = QAction('Configuration', self) + config_action.triggered.connect(self.show_config_dialog) + settings_menu.addAction(config_action) + + # Initialize processing thread as None + self.processing_thread = None + + def add_files(self): + files, _ = QFileDialog.getOpenFileNames( + self, + "Select files to process", + "", + "All Files (*.*)" + ) + for file in files: + self.file_list.addItem(file) + + def remove_files(self): + for item in self.file_list.selectedItems(): + self.file_list.takeItem(self.file_list.row(item)) + + def process_files(self): + if self.processing_thread and self.processing_thread.isRunning(): + return + + files = [self.file_list.item(i).text() + for i in range(self.file_list.count())] + + if not files: + self.output_text.appendPlainText("Error: No files selected") + return + + # Clear the output box before starting + self.output_text.clear() + + # Get selected task + task_id = self.task_group.checkedId() + tasks = ['summary', 'translate', 'distill', 'correct'] + task = tasks[task_id] + instruction = "" + + config = self.config + + self.processing_thread = ProcessingThread( + config=config, + task=task, + instruction=instruction, + files=files, + selected_template=self.selected_template + ) + + # Update progress handler to check for clear signal + def handle_progress(msg): + if msg == "<>": + self.output_text.clear() + else: + self.output_text.appendPlainText(msg) + + self.processing_thread.progress_signal.connect(handle_progress) + self.processing_thread.finished_signal.connect(self.processing_finished) + self.processing_thread.start() + + 
self.process_button.setEnabled(False) + + + def processing_finished(self, results): + self.output_text.appendPlainText("\nProcessing completed!") + for input_file, output_file in results: + self.output_text.appendPlainText( + f"\nProcessed {input_file}\nOutput saved to {output_file}" + ) + self.process_button.setEnabled(True) + + def show_config_dialog(self): + dialog = ConfigDialog(self.config, self) + if dialog.exec_() == QDialog.Accepted: + # Existing config updates... + self.config.api_url = dialog.api_url_input.text() + self.config.api_password = dialog.api_password_input.text() + self.config.temp = dialog.temp_input.value() + self.config.rep_pen = dialog.rep_pen_input.value() + self.config.top_k = dialog.top_k_input.value() + self.config.top_p = dialog.top_p_input.value() + self.config.min_p = dialog.min_p_input.value() + self.selected_template = dialog.template_combo.currentText() + + # Update translation language in config + self.config.translation_language = dialog.translation_language_input.text() + + # Save config to file + self.save_config() + + # Immediately check API with new settings + self.api_ready = False + self.check_api() + + # If we have an active processing thread, update its config and refresh instructions + if self.processing_thread and self.processing_thread.isRunning(): + self.processing_thread.config = self.config + if hasattr(self.processing_thread, 'processor'): + self.processing_thread.processor.update_config(self.config) + + def load_config(self): + """Load configuration from JSON file.""" + try: + if os.path.exists(self.config_file): + with open(self.config_file) as f: + config_data = json.load(f) + + + # Load template if it exists + self.selected_template = config_data.pop('selected_template', None) + + return LLMConfig(**config_data) + else: + # Return default config + return LLMConfig( + templates_directory="./templates", + api_url="http://localhost:5001", + api_password="" + ) + except Exception as e: + print(f"Error loading 
config: {e}") + return LLMConfig( + templates_directory="./templates", + api_url="http://localhost:5001", + api_password="" + ) + + def save_config(self): + """Save current configuration to JSON file.""" + try: + config_data = { + 'templates_directory': self.config.templates_directory, + 'api_url': self.config.api_url, + 'api_password': self.config.api_password, + 'temp': self.config.temp, + 'rep_pen': self.config.rep_pen, + 'top_k': self.config.top_k, + 'top_p': self.config.top_p, + 'min_p': self.config.min_p, + 'selected_template': self.selected_template, + 'translation_language': self.config.translation_language + } + + with open(self.config_file, 'w') as f: + json.dump(config_data, f, indent=4) + + except Exception as e: + print(f"Error saving config: {e}") + + +class ConfigDialog(QDialog): + def __init__(self, config, parent=None): + super().__init__(parent) + self.config = config + + self.initUI() + + def initUI(self): + self.setWindowTitle('Configuration') + self.setModal(True) + layout = QVBoxLayout() + + # API Settings + api_group = QGroupBox("API Settings") + api_layout = QVBoxLayout() + + self.api_url_input = QLineEdit(self.config.api_url) + api_layout.addWidget(QLabel("API URL:")) + api_layout.addWidget(self.api_url_input) + + self.api_password_input = QLineEdit(self.config.api_password) + self.api_password_input.setEchoMode(QLineEdit.Password) + api_layout.addWidget(QLabel("API Password:")) + api_layout.addWidget(self.api_password_input) + + self.translation_language_input = QLineEdit(self.config.translation_language) + #self.translation_language_input.setEchoMode(QLineEdit.translation_language) + api_layout.addWidget(QLabel("Translate to language:")) + api_layout.addWidget(self.translation_language_input) + + api_group.setLayout(api_layout) + + # Sampler Settings + sampler_group = QGroupBox("Sampler Settings") + sampler_layout = QVBoxLayout() + + self.temp_input = QDoubleSpinBox() + self.temp_input.setRange(0, 2) + self.temp_input.setSingleStep(0.1) + 
self.temp_input.setValue(self.config.temp) + sampler_layout.addWidget(QLabel("Temperature:")) + sampler_layout.addWidget(self.temp_input) + + self.rep_pen_input = QDoubleSpinBox() + self.rep_pen_input.setRange(0, 10) + self.rep_pen_input.setSingleStep(0.1) + self.rep_pen_input.setValue(self.config.rep_pen) + sampler_layout.addWidget(QLabel("Repetition Penalty:")) + sampler_layout.addWidget(self.rep_pen_input) + + self.top_k_input = QSpinBox() + self.top_k_input.setRange(0, 100) + self.top_k_input.setValue(self.config.top_k) + sampler_layout.addWidget(QLabel("Top K:")) + sampler_layout.addWidget(self.top_k_input) + + self.top_p_input = QDoubleSpinBox() + self.top_p_input.setRange(0, 1) + self.top_p_input.setSingleStep(0.1) + self.top_p_input.setValue(self.config.top_p) + sampler_layout.addWidget(QLabel("Top P:")) + sampler_layout.addWidget(self.top_p_input) + + self.min_p_input = QDoubleSpinBox() + self.min_p_input.setRange(0, 1) + self.min_p_input.setSingleStep(0.01) + self.min_p_input.setValue(self.config.min_p) + sampler_layout.addWidget(QLabel("Min P:")) + sampler_layout.addWidget(self.min_p_input) + + sampler_group.setLayout(sampler_layout) + + + # Template Selection + template_group = QGroupBox("Template Selection") + template_layout = QVBoxLayout() + + self.template_combo = QComboBox() + # Load available templates + template_path = Path(self.config.templates_directory) + if template_path.exists(): + templates = [f.stem for f in template_path.glob('*.json')] + self.template_combo.addItems(['Auto'] + templates) + else: + self.template_combo.addItem('Auto') + + template_layout.addWidget(QLabel("Model Template:")) + template_layout.addWidget(self.template_combo) + + template_group.setLayout(template_layout) + + # Buttons + button_layout = QHBoxLayout() + save_button = QPushButton('Save') + save_button.clicked.connect(self.accept) + cancel_button = QPushButton('Cancel') + cancel_button.clicked.connect(self.reject) + button_layout.addWidget(save_button) + 
button_layout.addWidget(cancel_button) + + # Add all groups to main layout + layout.addWidget(api_group) + layout.addWidget(sampler_group) + layout.addWidget(template_group) + layout.addLayout(button_layout) + + self.setLayout(layout) + +def main(): + app = QApplication(sys.argv) + gui = ChunkerGUI() + gui.show() + sys.exit(app.exec_()) + +if __name__ == '__main__': + main() diff --git a/chunkify.py b/chunkify.py index b6b8274..66e2cdb 100644 --- a/chunkify.py +++ b/chunkify.py @@ -1,478 +1,494 @@ -import os -import re -import json -import random -import requests -from pathlib import Path -from dataclasses import dataclass -from typing import Dict, List, Optional, Union -from urllib.parse import urlparse -from chunker_regex import chunk_regex -import threading -import time -from extractous import Extractor -import sys - -@dataclass -class LLMConfig: - """ Configuration for LLM processing. - """ - templates_directory: str - api_url: str - api_password: str - translation_language: str - text_completion: bool = False - gen_count: int = 500 #not used - temp: float = 0.2 - rep_pen: float = 1 - min_p: float = 0.02 - top_k: int = 0 - top_p: int = 1 - - - - @classmethod - def from_json(cls, path: str): - """ Load configuration from JSON file. - Expects a JSON object with the same field names as the class. - """ - with open(path) as f: - config_dict = json.load(f) - return cls(**config_dict) - -class LLMProcessor: - def __init__(self, config, task): - """ Initialize the LLM processor with given configuration. 
- """ - self.config = config - self.api_function_urls = { - "tokencount": "/api/extra/tokencount", - "interrogate": "/api/v1/generate", - "max_context_length": "/api/extra/true_max_context_length", - "check": "/api/extra/generate/check", - "abort": "/api/extra/abort", - "version": "/api/extra/version", - "model": "/api/v1/model", - "generate": "/api/v1/generate", - } - self._update_instructions() - self.templates_directory = config.templates_directory - self.api_url = config.api_url - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {config.api_password}", - } - self.genkey = self._create_genkey() - self.templates = self._get_templates() - self.model = self._get_model() - self.max_context = self._get_max_context_length() - self.generated = False - self.system_instruction = "You are a helpful assistant." - self.task = task - self.max_chunk = int((self.max_context // 2) *.9) # give room for template - self.max_length = self.max_context // 2 - - def _update_instructions(self): - """Update instructions based on current config""" - self.summary_instruction = "Extract the key points, themes and actions from the text succinctly without developing any conclusions or commentary." - self.translate_instruction = f"Translate the entire document into {self.config.translation_language}. Maintain linguistic flourish and authorial style as much as possible. Write the full contents without condensing the writing or modernizing the language." - self.distill_instruction = "Rewrite the text to be as concise as possible without losing meaning." - self.correct_instruction = "Correct any grammar, spelling, style, or format errors in the text. Do not alter the text or otherwise change the meaning or style." 
- - def update_config(self, new_config): - """Update config and refresh instructions""" - self.config = new_config - self._update_instructions() - self.templates_directory = config.templates_directory - self.api_url = config.api_url - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {config.api_password}", - } - self.genkey = self._create_genkey() - self.templates = self._get_templates() - self.model = self._get_model() - self.max_context = self._get_max_context_length() - self.generated = False - self.system_instruction = "You are a helpful assistant." - self.task = task - self.max_chunk = int((self.max_context // 2) *.9) # give room for template - self.max_length = self.max_context // 2 - - def _get_templates(self): - """ Look in the templates directory and load JSON template files. - Falls back to Alpaca format if no valid templates found. - """ - templates = {} - alpaca_template = { - "name": ["alpaca"], - "akas": [], - "system_start": "### System:", - "system_end": "\n", - "user_start": "### Human:", - "user_end": "\n", - "assistant_start": "### Assistant:", - "assistant_end": "\n" - } - try: - template_path = Path(self.templates_directory) - if not template_path.exists(): - return {"alpaca": alpaca_template} - for file in template_path.glob('*.json'): - try: - with open(file) as f: - template = json.load(f) - required_fields = [ - "system_start", "system_end", - "user_start", "user_end", - "assistant_start", "assistant_end" - ] - if all(field in template for field in required_fields): - base_name = file.stem - if "akas" not in template: - template["akas"] = [] - if "name" not in template: - template["name"] = [base_name] - templates[base_name] = template - except (json.JSONDecodeError, KeyError) as e: - print(f"Error loading template {file}: {str(e)}") - continue - if not templates: - return {"alpaca": alpaca_template} - return templates - except Exception as e: - print(f"Error loading templates directory: {str(e)}") - return 
{"alpaca": alpaca_template} - - def _get_model(self): - """ Queries Kobold for current model name and finds matching template. - Prefers exact matches, then version matches, then base model matches. - Exits the script if no match found. - Overly complicated. - """ - if self.config.text_completion: - print("Using text completion mode") - return { - "name": ["Completion"], - "user": "", - "assistant": "", - "system": None, - } - model_name = self._call_api("model") - if not model_name: - print("Could not get model name from API, exiting script") - sys.exit(1) - print(f"Kobold reports model: {model_name}") - - def normalize(s): - """ Remove special chars and lowercase for matching. - """ - return re.sub(r"[^a-z0-9]", "", s.lower()) - - model_name_normalized = normalize(model_name) - best_match = None - best_match_length = 0 - best_match_version = 0 - for template in self.templates.values(): - names_to_check = template.get("name", []) - if isinstance(names_to_check, str): - names_to_check = [names_to_check] - names_to_check.extend(template.get("akas", [])) - for name in names_to_check: - normalized_name = normalize(name) - if normalized_name in model_name_normalized: - version_match = re.search(r'(\d+)(?:\.(\d+))?', name) - current_version = float(f"{version_match.group(1)}.{version_match.group(2) or '0'}") if version_match else 0 - name_length = len(normalized_name) - if current_version > best_match_version or \ - (current_version == best_match_version and name_length > best_match_length): - best_match = template - best_match_length = name_length - best_match_version = current_version - if best_match: - print(f"Selected template: {best_match.get('name', ['Unknown'])[0]}") - return best_match - print(f"No version-specific template found, trying base model match...") - for template in self.templates.values(): - names_to_check = template.get("name", []) - if isinstance(names_to_check, str): - names_to_check = [names_to_check] - names_to_check.extend(template.get("akas", 
[])) - for name in names_to_check: - normalized_name = normalize(name) - base_name = re.sub(r'\d+(?:\.\d+)?', '', normalized_name) - if base_name in model_name_normalized: - name_length = len(base_name) - if name_length > best_match_length: - best_match = template - best_match_length = name_length - if best_match: - print(f"Selected base template: {best_match.get('name', ['Unknown'])[0]}") - return best_match - print("No matching template found, exiting script") - sys.exit(1) - - def _call_api(self, api_function, payload=None): - """ Call the Kobold API. - Some API calls are POSTs and some are GETs. - """ - if api_function not in self.api_function_urls: - raise ValueError(f"Invalid API function: {api_function}") - url = f"{self.api_url}{self.api_function_urls[api_function]}" - try: - - if api_function in ["tokencount", "generate", "check", "interrogate", "abort"]: - response = requests.post(url, json=payload, headers=self.headers) - result = response.json() - if api_function == "tokencount": - return int(result.get("value")) - elif api_function == "abort": - return result.get("success") - else: - return result["results"][0].get("text") - else: - response = requests.get(url, json=payload, headers=self.headers) - result = response.json() - if resulted := result.get("result", None): - return resulted - else: - return int(result.get("value", None)) - except requests.RequestException as e: - print(f"Error calling API: {str(e)}") - return None - - def _get_initial_chunk(self, content): - """ We are chunking based on natural break points. - Only works well for Germanic and Romance languages. - Ideally content is in markdown format. 
- """ - total_tokens = self._get_token_count(content) - print(f"Content tokens to chunk: {total_tokens}") - if total_tokens < self.max_chunk: - return content - matches = chunk_regex.finditer(content) - current_size = 0 - chunks = [] - for match in matches: - chunk = match.group(0) - chunk_size = self._get_token_count(chunk) - if current_size + chunk_size > self.max_chunk: - if not chunks: - chunks.append(chunk) - break - chunks.append(chunk) - current_size += chunk_size - return ''.join(chunks) - - def compose_prompt(self, instruction="", content=""): - """ Create the prompt that gets sent to the LLM and specify samplers. - """ - prompt = self.get_prompt(instruction, content) - payload = { - "prompt": prompt, - "max_length": self.max_length, - "genkey": self.genkey, - "top_p": self.config.top_p, - "top_k": self.config.top_k, - "temp": self.config.temp, - "rep_pen": self.config.rep_pen, - "min_p": self.config.min_p, - } - return payload - - def get_prompt(self, instruction="", content=""): - """ Create a prompt to send to the LLM using the instruct template - or basic text completion. - """ - if not self.model: - raise ValueError("No model template loaded") - if self.model["name"] == ["Completion"]: - return f"{content}".strip() - user_text = f"{content}{instruction}" - if not user_text: - raise ValueError("No user text provided (both instruction and content are empty)") - prompt_parts = [] - if self.model.get("system") is not None: - prompt_parts.extend([ - self.model["system_start"], - self.model["system_instruction"], - self.model["system_end"] - ]) - prompt_parts.extend([ - self.model["user_start"], - user_text, - self.model["user_end"], - self.model["assistant_start"] - ]) - return "".join(prompt_parts) - - def route_task(self, task="summary", content=""): - """ Send to appropriate function. 
- """ - if task in ["correct", "translate", "distill", "summary"]: - instruction = getattr(self, f"{task}_instruction") - responses = self.process_in_chunks(instruction, content) - return responses - else: - raise ValueError(f"Unknown task: {task}") - - def process_in_chunks(self, instruction="", content=""): - """ Process the content into chunks. - """ - chunks = [] - remaining = content - while remaining: - chunk = self._get_initial_chunk(remaining) - chunk_len = len(chunk) - if chunk_len == 0: - print("Warning: Got zero-length chunk") - break - chunks.append(chunk) - remaining = remaining[len(chunk):].strip() - responses = [] - total_chunks = len(chunks) - print("Starting chunk processing...") - for i, chunk in enumerate(chunks, 1): - chunk_tokens = self._get_token_count(chunk) - print(f"Chunk {i} of {total_chunks}, Size: {chunk_tokens}\n") - time.sleep(2) - response = self.generate_with_status(self.compose_prompt( - instruction=instruction, - content=chunk - )) - if response: - responses.append(response) - return responses - - def generate_with_status(self, prompt): - """ Threads generation so that we can stream the output onto the - console otherwise we stare at a blank screen. - """ - self.generated = False - monitor = threading.Thread( - target=self._monitor_generation, - daemon=True - ) - monitor.start() - try: - result = self._call_api("generate", prompt) - self.generated = True - monitor.join() - return result - except Exception as e: - print(f"Generation error: {e}") - return None - - def _monitor_generation(self): - """ Write generation onto the terminal as it is created. - """ - generating = False - payload = { - 'genkey': self.genkey - } - while not self.generated: - result = self._call_api("check", payload) - if not result: - time.sleep(2) - continue - time.sleep(1) - clear_console() - print(f"{result}") - - @staticmethod - def _create_genkey(): - """ Create a unique generation key. - Prevents kobold from returning your generation to another query. 
- """ - return f"KCPP{''.join(str(random.randint(0, 9)) for _ in range(4))}" - - def _get_max_context_length(self): - """ Get the maximum context length from the API. - """ - max_context = self._call_api("max_context_length") - print(f"Model has maximum context length of: {max_context}") - return max_context - - def _get_token_count(self, content): - """ Get the token count for a piece of content. - """ - payload = {"prompt": content, "genkey": self.genkey} - return self._call_api("tokencount", payload) - - def _get_content(self, content): - """ Read text from a file to chunk. - """ - extractor = Extractor() - result, metadata = extractor.extract_file_to_string(content) - return result, metadata - -def check_api(api_url): - """ See if the API is ready - """ - url = f"{api_url}/api/v1/info/version/" - if requests.get(url, json="", headers=""): - return True - return False - -def write_output(output_path, task, responses, metadata): - """ Write the task response to a file. - """ - try: - with open(output_path, 'w', encoding='utf-8') as f: - f.write(f"File: {metadata.get('resourceName', 'Unknown')}\n") - f.write(f"Type: {metadata.get('Content-Type', 'Unknown')}\n") - f.write(f"Encoding: {metadata.get('Content-Encoding', 'Unknown')}\n") - f.write(f"Length: {metadata.get('Content-Length', 'Unknown')}\n\n") - for response in responses: - f.write(f"{response}\n\n") - print(f"\nOutput written to: {output_path}") - except Exception as e: - print(f"Error writing output file: {str(e)}") - -def clear_console(): - """ Clears the screen so the output can refresh. 
- """ - command = 'clear' - if os.name in ('nt', 'dos'): - command = 'cls' - os.system(command) - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="LLM Processor for Kobold API") - parser.add_argument("--config", type=str, help="Path to JSON config file") - parser.add_argument("--content", type=str, help="Content to process") - parser.add_argument("--api-url", type=str, default="http://localhost:5001", help="URL for the LLM API") - parser.add_argument("--api-password", type=str, default="", help="Password for the LLM API") - parser.add_argument("--templates", type=str, default="./templates", help="Directory for instruct templates") - parser.add_argument("--task", type=str, default="summary", help="Task: summary, translate, distill, correct") - parser.add_argument("--file", type=str, default="output.txt", help="Output to file path") - args = parser.parse_args() - try: - if args.config: - config = LLMConfig.from_json(args.config) - else: - config = LLMConfig( - templates_directory=args.templates, - api_url=args.api_url, - api_password=args.api_password, - translation_language="English" - ) - task = args.task.lower() - processor = LLMProcessor(config, task) - content, metadata = processor._get_content(args.content) - file = args.file - if task in ["translate", "distill", "correct", "summary"]: - responses = processor.route_task(task, content) - write_output(file, task, responses, metadata) - else: - print("Error - No task selected from: summary, translate, distill, correct") - except KeyboardInterrupt: - print("\nExiting...") - exit(0) - except Exception as e: - print(f"Fatal error: {str(e)}") - import traceback - traceback.print_exc() - exit(1) - +import os +import re +import json +import random +import requests +from pathlib import Path +from dataclasses import dataclass +from typing import Dict, List, Optional, Union +from urllib.parse import urlparse +from chunker_regex import chunk_regex +import threading +import time 
+from extractous import Extractor +import sys + +@dataclass +class LLMConfig: + """ Configuration for LLM processing. + """ + templates_directory: str + api_url: str + api_password: str + translation_language: str + text_completion: bool = False + gen_count: int = 500 #not used + temp: float = 0.2 + rep_pen: float = 1 + min_p: float = 0.02 + top_k: int = 0 + top_p: int = 1 + + @classmethod + def from_json(cls, path: str): + """ Load configuration from JSON file. + Expects a JSON object with the same field names as the class. + """ + with open(path) as f: + config_dict = json.load(f) + return cls(**config_dict) + +class LLMProcessor: + def __init__(self, config, task): + """ Initialize the LLM processor with given configuration. + """ + self.config = config + self.api_function_urls = { + "tokencount": "/api/extra/tokencount", + "interrogate": "/api/v1/generate", + "max_context_length": "/api/extra/true_max_context_length", + "check": "/api/extra/generate/check", + "abort": "/api/extra/abort", + "version": "/api/extra/version", + "model": "/api/v1/model", + "generate": "/api/v1/generate", + } + self._update_instructions() + self.templates_directory = config.templates_directory + self.api_url = config.api_url + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {config.api_password}", + } + self.genkey = self._create_genkey() + self.templates = self._get_templates() + self.model = self._get_model() + self.max_context = self._get_max_context_length() + self.generated = False + self.system_instruction = "You are a helpful assistant." + self.task = task + self.max_chunk = int((self.max_context // 2) *.9) # give room for template + self.max_length = self.max_context // 2 + + def _update_instructions(self): + """Update instructions based on current config""" + self.summary_instruction = "Extract the key points, themes and actions from the text succinctly without developing any conclusions or commentary." 
+        self.translate_instruction = f"Translate the entire document into {self.config.translation_language}. Maintain linguistic flourish and authorial style as much as possible. Write the full contents without condensing the writing or modernizing the language."
+        self.distill_instruction = "Rewrite the text to be as concise as possible without losing meaning."
+        self.correct_instruction = "Correct any grammar, spelling, style, or format errors in the text. Do not alter the text or otherwise change the meaning or style."
+
+    def update_config(self, new_config):
+        """Update config and refresh instructions"""
+        self.config = new_config
+        self._update_instructions()
+        self.templates_directory = new_config.templates_directory
+        self.api_url = new_config.api_url
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {new_config.api_password}",
+        }
+        self.genkey = self._create_genkey()
+        self.templates = self._get_templates()
+        self.model = self._get_model()
+        self.max_context = self._get_max_context_length()
+        self.generated = False
+        self.system_instruction = "You are a helpful assistant."
+        # self.task intentionally left unchanged (original 'self.task = task' raised NameError)
+        self.max_chunk = int((self.max_context // 2) *.9) # give room for template
+        self.max_length = self.max_context // 2
+
+    def _get_templates(self):
+        """ Look in the templates directory and load JSON template files.
+            Falls back to Alpaca format if no valid templates found.
+ """ + templates = {} + alpaca_template = { + "name": ["alpaca"], + "akas": [], + "system_start": "### System:", + "system_end": "\n", + "user_start": "### Human:", + "user_end": "\n", + "assistant_start": "### Assistant:", + "assistant_end": "\n" + } + try: + template_path = Path(self.templates_directory) + if not template_path.exists(): + return {"alpaca": alpaca_template} + for file in template_path.glob('*.json'): + try: + with open(file) as f: + template = json.load(f) + required_fields = [ + "system_start", "system_end", + "user_start", "user_end", + "assistant_start", "assistant_end" + ] + if all(field in template for field in required_fields): + base_name = file.stem + if "akas" not in template: + template["akas"] = [] + if "name" not in template: + template["name"] = [base_name] + templates[base_name] = template + except (json.JSONDecodeError, KeyError) as e: + print(f"Error loading template {file}: {str(e)}") + continue + if not templates: + return {"alpaca": alpaca_template} + return templates + except Exception as e: + print(f"Error loading templates directory: {str(e)}") + return {"alpaca": alpaca_template} + + def _get_model(self): + """ Queries Kobold for current model name and finds matching template. + Prefers exact matches, then version matches, then base model matches. + Exits the script if no match found. + Overly complicated. + """ + if self.config.text_completion: + print("Using text completion mode") + return { + "name": ["Completion"], + "user": "", + "assistant": "", + "system": None, + } + model_name = self._call_api("model") + if not model_name: + print("Could not get model name from API, exiting script") + sys.exit(1) + print(f"Kobold reports model: {model_name}") + + def normalize(s): + """ Remove special chars and lowercase for matching. 
+ """ + return re.sub(r"[^a-z0-9]", "", s.lower()) + + model_name_normalized = normalize(model_name) + best_match = None + best_match_length = 0 + best_match_version = 0 + for template in self.templates.values(): + names_to_check = template.get("name", []) + if isinstance(names_to_check, str): + names_to_check = [names_to_check] + names_to_check.extend(template.get("akas", [])) + for name in names_to_check: + normalized_name = normalize(name) + if normalized_name in model_name_normalized: + version_match = re.search(r'(\d+)(?:\.(\d+))?', name) + current_version = float(f"{version_match.group(1)}.{version_match.group(2) or '0'}") if version_match else 0 + name_length = len(normalized_name) + if current_version > best_match_version or \ + (current_version == best_match_version and name_length > best_match_length): + best_match = template + best_match_length = name_length + best_match_version = current_version + if best_match: + print(f"Selected template: {best_match.get('name', ['Unknown'])[0]}") + return best_match + print(f"No version-specific template found, trying base model match...") + for template in self.templates.values(): + names_to_check = template.get("name", []) + if isinstance(names_to_check, str): + names_to_check = [names_to_check] + names_to_check.extend(template.get("akas", [])) + for name in names_to_check: + normalized_name = normalize(name) + base_name = re.sub(r'\d+(?:\.\d+)?', '', normalized_name) + if base_name in model_name_normalized: + name_length = len(base_name) + if name_length > best_match_length: + best_match = template + best_match_length = name_length + if best_match: + print(f"Selected base template: {best_match.get('name', ['Unknown'])[0]}") + return best_match + print("No matching template found, exiting script") + sys.exit(1) + + def _call_api(self, api_function, payload=None): + """ Call the Kobold API. + Some API calls are POSTs and some are GETs. 
+ """ + if api_function not in self.api_function_urls: + raise ValueError(f"Invalid API function: {api_function}") + url = f"{self.api_url}{self.api_function_urls[api_function]}" + try: + + if api_function in ["tokencount", "generate", "check", "interrogate", "abort"]: + response = requests.post(url, json=payload, headers=self.headers) + result = response.json() + if api_function == "tokencount": + return int(result.get("value")) + elif api_function == "abort": + return result.get("success") + else: + return result["results"][0].get("text") + else: + response = requests.get(url, json=payload, headers=self.headers) + result = response.json() + if resulted := result.get("result", None): + return resulted + else: + return int(result.get("value", None)) + except requests.RequestException as e: + print(f"Error calling API: {str(e)}") + return None + + def _get_initial_chunk(self, content): + """ We are chunking based on natural break points. + Only works well for Germanic and Romance languages. + Ideally content is in markdown format. + """ + total_tokens = self._get_token_count(content) + print(f"Content tokens to chunk: {total_tokens}") + if total_tokens < self.max_chunk: + return content + matches = chunk_regex.finditer(content) + current_size = 0 + chunks = [] + for match in matches: + chunk = match.group(0) + chunk_size = self._get_token_count(chunk) + if current_size + chunk_size > self.max_chunk: + if not chunks: + chunks.append(chunk) + break + chunks.append(chunk) + current_size += chunk_size + return ''.join(chunks) + + def compose_prompt(self, instruction="", content=""): + """ Create the prompt that gets sent to the LLM and specify samplers. 
+ """ + prompt = self.get_prompt(instruction, content) + payload = { + "prompt": prompt, + "max_length": self.max_length, + "genkey": self.genkey, + "top_p": self.config.top_p, + "top_k": self.config.top_k, + "temp": self.config.temp, + "rep_pen": self.config.rep_pen, + "min_p": self.config.min_p, + } + return payload + + def get_prompt(self, instruction="", content=""): + """ Create a prompt to send to the LLM using the instruct template + or basic text completion. + """ + if not self.model: + raise ValueError("No model template loaded") + if self.model["name"] == ["Completion"]: + return f"{content}".strip() + user_text = f"{content}{instruction}" + if not user_text: + raise ValueError("No user text provided (both instruction and content are empty)") + prompt_parts = [] + if self.model.get("system") is not None: + prompt_parts.extend([ + self.model["system_start"], + self.model["system_instruction"], + self.model["system_end"] + ]) + prompt_parts.extend([ + self.model["user_start"], + user_text, + self.model["user_end"], + self.model["assistant_start"] + ]) + return "".join(prompt_parts) + + def route_task(self, task="summary", content=""): + """ Send to appropriate function. + """ + if task in ["correct", "translate", "distill", "summary"]: + instruction = getattr(self, f"{task}_instruction") + responses = self.process_in_chunks(instruction, content) + return responses + else: + raise ValueError(f"Unknown task: {task}") + + def process_in_chunks(self, instruction="", content=""): + """ Process the content into chunks. 
+ """ + chunks = [] + remaining = content + chunk_num = 0 + while remaining: + current_section = remaining[:45000] + remaining = remaining[45000:] + chunk = self._get_initial_chunk(current_section) + chunk_len = len(chunk) + #print(chunk_len) + if chunk_len == 0: + print("Warning: Got zero-length chunk") + continue + chunks.append(chunk) + remaining = current_section[len(chunk):].strip() + remaining + chunk_num += 1 + print(f"Chunk: {chunk_num}") + + responses = [] + total_chunks = len(chunks) + print("Starting chunk processing...") + for i, chunk in enumerate(chunks, 1): + chunk_tokens = self._get_token_count(chunk) + print(f"Chunk {i} of {total_chunks}, Size: {chunk_tokens}\n") + response = self.generate_with_status(self.compose_prompt( + instruction=instruction, + content=chunk + )) + if response: + responses.append(response) + return responses + + def generate_with_status(self, prompt): + """ Threads generation so that we can stream the output onto the + console otherwise we stare at a blank screen. + """ + self.generated = False + monitor = threading.Thread( + target=self._monitor_generation, + daemon=True + ) + monitor.start() + try: + result = self._call_api("generate", prompt) + self.generated = True + monitor.join() + return result + except Exception as e: + print(f"Generation error: {e}") + return None + + def _monitor_generation(self): + """ Write generation onto the terminal as it is created. + """ + generating = False + payload = { + 'genkey': self.genkey + } + while not self.generated: + result = self._call_api("check", payload) + if not result: + time.sleep(2) + continue + time.sleep(1) + clear_console() + print(f"{result}") + + @staticmethod + def _create_genkey(): + """ Create a unique generation key. + Prevents kobold from returning your generation to another query. + """ + return f"KCPP{''.join(str(random.randint(0, 9)) for _ in range(4))}" + + def _get_max_context_length(self): + """ Get the maximum context length from the API. 
+ """ + max_context = self._call_api("max_context_length") + print(f"Model has maximum context length of: {max_context}") + return max_context + + def _get_token_count(self, content): + """ Get the token count for a piece of content. + """ + payload = {"prompt": content, "genkey": self.genkey} + return self._call_api("tokencount", payload) + + def _get_content(self, content): + """ Read text from a file to chunk. + """ + extractor = Extractor() + extractor.set_extract_string_max_length(100000000) + + result, metadata = extractor.extract_file_to_string(content) + print(len(result)) + print(metadata) + return result, metadata + +def check_api(api_url): + """ See if the API is ready + """ + url = f"{api_url}/api/v1/info/version/" + if requests.get(url, json="", headers=""): + return True + return False + +def write_output(output_path, task, responses, metadata): + """ Write the task response to a file. + """ + processing_time = metadata.get('Processing-Time', 0) + hours = int(processing_time // 3600) + minutes = int((processing_time % 3600) // 60) + seconds = int(processing_time % 60) + + try: + with open(output_path, 'w', encoding='utf-8') as f: + f.write(f"File: {metadata.get('resourceName', 'Unknown')}\n") + f.write(f"Type: {metadata.get('Content-Type', 'Unknown')}\n") + f.write(f"Encoding: {metadata.get('Content-Encoding', 'Unknown')}\n") + f.write(f"Length: {metadata.get('Content-Length', 'Unknown')}\n") + f.write(f"Total Time: {hours:02d}:{minutes:02d}:{seconds:02d}\n\n") + for response in responses: + f.write(f"{response}\n\n") + print(f"\nOutput written to: {output_path}") + except Exception as e: + print(f"Error writing output file: {str(e)}") + +def clear_console(): + """ Clears the screen so the output can refresh. 
+ """ + command = 'clear' + if os.name in ('nt', 'dos'): + command = 'cls' + os.system(command) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="LLM Processor for Kobold API") + parser.add_argument("--config", type=str, help="Path to JSON config file") + parser.add_argument("--content", type=str, help="Content to process") + parser.add_argument("--api-url", type=str, default="http://localhost:5001", help="URL for the LLM API") + parser.add_argument("--api-password", type=str, default="", help="Password for the LLM API") + parser.add_argument("--templates", type=str, default="./templates", help="Directory for instruct templates") + parser.add_argument("--task", type=str, default="summary", help="Task: summary, translate, distill, correct") + parser.add_argument("--file", type=str, default="output.txt", help="Output to file path") + args = parser.parse_args() + try: + if args.config: + config = LLMConfig.from_json(args.config) + else: + config = LLMConfig( + templates_directory=args.templates, + api_url=args.api_url, + api_password=args.api_password, + translation_language="English" + ) + task = args.task.lower() + processor = LLMProcessor(config, task) + content, metadata = processor._get_content(args.content) + file = args.file + if task in ["translate", "distill", "correct", "summary"]: + start_time = time.time() + responses = processor.route_task(task, content) + metadata["Processing-Time"] = time.time() - start_time + write_output(file, task, responses, metadata) + else: + print("Error - No task selected from: summary, translate, distill, correct") + except KeyboardInterrupt: + print("\nExiting...") + exit(0) + except Exception as e: + print(f"Fatal error: {str(e)}") + import traceback + traceback.print_exc() + exit(1) +