From 417b69e4e8687e52adaee19cbad986ff3ed10d07 Mon Sep 17 00:00:00 2001
From: levoz92
Date: Sun, 1 Dec 2024 21:46:17 -0500
Subject: [PATCH] Added further optimization features such as dynamic
 convergence check, dynamic learning rate scheduling and efficient data
 loading by prefetching batches

---
 src/qml_benchmarks/model_utils.py | 132 ++++++++++++++++++++++++++----
 1 file changed, 116 insertions(+), 16 deletions(-)

diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py
index 312118e2..af1b94af 100644
--- a/src/qml_benchmarks/model_utils.py
+++ b/src/qml_benchmarks/model_utils.py
@@ -22,10 +22,74 @@
 import optax
 import jax
 import jax.numpy as jnp
+import multiprocessing as mp
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.utils import gen_batches
 
 
+class ConvergenceCriterion:
+    """
+    Implements a dynamic convergence criterion for training loops. The purpose of this
+    class is to monitor the stability of the loss over a specified number of recent steps
+    (`patience`) and decide whether the model has converged.
+
+    Convergence is determined by checking whether the range of loss values in the
+    most recent `patience` steps is smaller than a given `tolerance`. This approach
+    provides flexibility and robustness compared to fixed thresholds, adapting to the
+    model's training behavior dynamically.
+
+    Attributes:
+        patience (int): Number of recent loss values to consider when evaluating convergence.
+        tolerance (float): The maximum allowable difference between the highest and lowest
+            loss values in the recent `patience` steps for convergence to be declared.
+        losses (list[float]): A list storing the loss values observed during training.
+
+    Methods:
+        check_convergence(loss):
+            Adds a new loss value to the history and evaluates whether the training
+            process has converged based on the recent losses.
+            - Returns `True` if the difference between the maximum and minimum loss
+              in the recent `patience` steps is less than `tolerance`.
+            - Returns `False` otherwise.
+
+    Example Usage:
+        # Initialize the convergence criterion
+        criterion = ConvergenceCriterion(patience=10, tolerance=0.001)
+
+        # Check for convergence in a training loop
+        for step in range(max_steps):
+            loss = compute_loss()
+            if criterion.check_convergence(loss):
+                print(f"Converged at step {step}")
+                break
+    """
+    def __init__(self, patience=10, tolerance=0.001):
+        self.patience = patience
+        self.tolerance = tolerance
+        self.losses = []
+
+    def check_convergence(self, loss):
+        """
+        Checks whether the model's training loss has converged based on recent loss values.
+
+        Args:
+            loss (float): The most recent loss value from the training loop.
+
+        Returns:
+            bool: True if the recent loss values indicate convergence; False otherwise.
+ """ + self.losses.append(loss) + if len(self.losses) < self.patience: + return False + recent_losses = self.losses[-self.patience:] + if max(recent_losses) - min(recent_losses) < self.tolerance: + return True + return False + +# Dynamic learning rate adjustment +def learning_rate_schedule(initial_lr, step, decay_rate=0.1, decay_steps=1000): + return initial_lr * (decay_rate ** (step // decay_steps)) + def train( model, loss_fn, optimizer, X, y, random_key_generator, convergence_interval=200 ): @@ -58,11 +122,16 @@ def train( if not model.batch_size / model.max_vmap % 1 == 0: raise Exception("Batch size must be multiple of max_vmap.") + # dynamic learning rate initialization + learning_rate = learning_rate_schedule(model.learning_rate, step) params = model.params_ - opt = optimizer(learning_rate=model.learning_rate) + opt = optimizer(learning_rate=learning_rate) opt_state = opt.init(params) grad_fn = jax.grad(loss_fn) + # The convergence criterion is initialized before the loop with a specified patience and tolerance. + criterion = ConvergenceCriterion(patience=convergence_interval, tolerance=0.001) + # jitting through the chunked_grad function can take a long time, # so we jit here and chunk after if model.jit: @@ -83,7 +152,10 @@ def update(params, opt_state, x, y): loss_history = [] converged = False start = time.time() - for step in range(model.max_steps): + + batches = prefetch_batches(get_batch(X, y, random_key_generator, model.batch_size)) + for step, (X_batch, y_batch) in enumerate(batches): + params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch) key = random_key_generator() X_batch, y_batch = get_batch(X, y, key, batch_size=model.batch_size) params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch) @@ -93,22 +165,28 @@ def update(params, opt_state, x, y): if np.isnan(loss_val): logging.info(f"nan encountered. Training aborted.") break + + # the criterion is checked at each step: + if criterion.check_convergence(loss_val): + logging.info(f"Model {model.__class__.__name__} converged after {step} steps.") + converged = True + break # decide convergence - if step > 2 * convergence_interval: - # get means of last two intervals and standard deviation of last interval - average1 = np.mean(loss_history[-convergence_interval:]) - average2 = np.mean( - loss_history[-2 * convergence_interval : -convergence_interval] - ) - std1 = np.std(loss_history[-convergence_interval:]) - # if the difference in averages is small compared to the statistical fluctuations, stop training. - if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: - logging.info( - f"Model {model.__class__.__name__} converged after {step} steps." - ) - converged = True - break + # if step > 2 * convergence_interval: + # # get means of last two intervals and standard deviation of last interval + # average1 = np.mean(loss_history[-convergence_interval:]) + # average2 = np.mean( + # loss_history[-2 * convergence_interval : -convergence_interval] + # ) + # std1 = np.std(loss_history[-convergence_interval:]) + # # if the difference in averages is small compared to the statistical fluctuations, stop training. + # if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: + # logging.info( + # f"Model {model.__class__.__name__} converged after {step} steps." 
+        #         )
+        #         converged = True
+        #         break
 
     end = time.time()
     loss_history = np.array(loss_history)
@@ -144,6 +222,28 @@ def get_batch(X, y, rnd_key, batch_size=32):
     )
     return X[rnd_indices], y[rnd_indices]
+def prefetch_batches(batch_generator, queue_size=10):
+    """
+    A prefetch mechanism for loading the next batch while the current batch is being processed.
+
+    Args:
+        batch_generator (generator): A generator that yields batches of data
+        queue_size (int): The maximum number of batches to prefetch
+
+    Returns:
+        A generator that yields batches of data
+    """
+    queue = mp.Queue(maxsize=queue_size)
+    def loader():
+        for batch in batch_generator:
+            queue.put(batch)
+        queue.put(None)  # sentinel: the batch generator is exhausted
+    mp.Process(target=loader, daemon=True).start()  # daemon + fork-style start; avoids hanging if training stops early
+    while True:
+        batch = queue.get()
+        if batch is None:
+            break
+        yield batch
 
 
 def get_from_dict(dict, key_list):
     """
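Example usage (a minimal sketch, not itself part of the patch): the snippet below shows how the three new helpers compose outside of train(). Only ConvergenceCriterion, learning_rate_schedule and prefetch_batches come from the patched module; the synthetic data, the batch_generator helper, the placeholder loss and all numeric values are illustrative assumptions. Because prefetch_batches builds a multiprocessing.Process around a local closure, the sketch assumes a fork-style process start method (the default on Linux).

    import numpy as np

    from qml_benchmarks.model_utils import (
        ConvergenceCriterion,
        learning_rate_schedule,
        prefetch_batches,
    )

    def batch_generator(X, y, batch_size=32, max_steps=5000):
        # one freshly sampled batch per optimization step
        for _ in range(max_steps):
            idx = np.random.randint(0, len(X), size=batch_size)
            yield X[idx], y[idx]

    if __name__ == "__main__":
        X = np.random.rand(1000, 4)
        y = np.random.randint(0, 2, size=1000)

        criterion = ConvergenceCriterion(patience=200, tolerance=1e-3)
        for step, (X_batch, y_batch) in enumerate(prefetch_batches(batch_generator(X, y))):
            lr = learning_rate_schedule(0.01, step)  # step-decay learning rate for this step
            # placeholder loss; a real training step would update the model parameters here using lr
            loss = float(np.mean((X_batch.mean(axis=1) - y_batch) ** 2))
            if criterion.check_convergence(loss):
                print(f"Converged after {step} steps (lr={lr})")
                break

Inside train() the same pieces are wired together automatically: patience maps to convergence_interval, so training stops once the last convergence_interval loss values vary by less than 0.001.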