132 changes: 116 additions & 16 deletions src/qml_benchmarks/model_utils.py
@@ -22,10 +22,74 @@
import optax
import jax
import jax.numpy as jnp
import threading
from queue import Queue
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils import gen_batches


class ConvergenceCriterion:
    """
    Implements a dynamic convergence criterion for training loops. The class
    monitors the stability of the loss over a specified number of recent
    steps (`patience`) and decides whether the model has converged.

    Convergence is declared when the range (maximum minus minimum) of the loss
    values over the most recent `patience` steps falls below `tolerance`.
    Unlike a fixed loss threshold, this criterion adapts to the scale and
    behavior of the individual training run.

    Attributes:
        patience (int): Number of recent loss values to consider when evaluating convergence.
        tolerance (float): Maximum allowable difference between the highest and lowest
            loss values in the most recent `patience` steps for convergence to be declared.
        losses (list[float]): Loss values observed so far during training.

    Methods:
        check_convergence(loss): Appends a new loss value to the history and
            returns True if the range of the most recent `patience` losses is
            below `tolerance`, False otherwise.

    Example usage:
        # Initialize the convergence criterion
        criterion = ConvergenceCriterion(patience=10, tolerance=0.001)

        # Check for convergence in a training loop
        for step in range(max_steps):
            loss = compute_loss()
            if criterion.check_convergence(loss):
                print(f"Converged at step {step}")
                break
    """
    def __init__(self, patience=10, tolerance=0.001):
        self.patience = patience
        self.tolerance = tolerance
        self.losses = []

    def check_convergence(self, loss):
        """
        Checks whether the training loss has converged based on recent loss values.

        Args:
            loss (float): The most recent loss value from the training loop.

        Returns:
            bool: True if the recent loss values indicate convergence; False otherwise.
        """
        self.losses.append(loss)
        if len(self.losses) < self.patience:
            return False
        recent_losses = self.losses[-self.patience:]
        return max(recent_losses) - min(recent_losses) < self.tolerance
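
# Illustrative example (not part of the training API): with patience=3 and
# tolerance=0.001, the criterion fires once three consecutive losses lie
# within a band narrower than 0.001:
#
#     criterion = ConvergenceCriterion(patience=3, tolerance=0.001)
#     criterion.check_convergence(0.5000)   # False (history shorter than patience)
#     criterion.check_convergence(0.4995)   # False (history shorter than patience)
#     criterion.check_convergence(0.4993)   # True: 0.5000 - 0.4993 = 0.0007 < 0.001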

# Dynamic learning rate adjustment: step decay of the initial learning rate
def learning_rate_schedule(initial_lr, step, decay_rate=0.1, decay_steps=1000):
    return initial_lr * (decay_rate ** (step // decay_steps))
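
# Illustrative values (not part of the training API): with initial_lr=0.01,
# decay_rate=0.1 and decay_steps=1000, the schedule is piecewise constant,
# dropping by a factor of 10 every 1000 steps:
#
#     learning_rate_schedule(0.01, 0)      # 0.01   (steps 0-999)
#     learning_rate_schedule(0.01, 1000)   # 0.001  (steps 1000-1999)
#     learning_rate_schedule(0.01, 2500)   # 0.0001 (steps 2000-2999)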

def train(
    model, loss_fn, optimizer, X, y, random_key_generator, convergence_interval=200
):
@@ -58,11 +122,16 @@ def train(
    if model.batch_size % model.max_vmap != 0:
        raise Exception("Batch size must be a multiple of max_vmap.")

    # dynamic learning rate: optax optimizers accept a schedule (a function of
    # the update count) in place of a fixed learning rate
    params = model.params_
    opt = optimizer(learning_rate=lambda count: learning_rate_schedule(model.learning_rate, count))
    opt_state = opt.init(params)
    grad_fn = jax.grad(loss_fn)

    # the convergence criterion is initialized before the loop with a specified patience and tolerance
    criterion = ConvergenceCriterion(patience=convergence_interval, tolerance=0.001)

    # jitting through the chunked_grad function can take a long time,
    # so we jit here and chunk after
    if model.jit:
@@ -83,7 +152,10 @@ def update(params, opt_state, x, y):
    loss_history = []
    converged = False
    start = time.time()

    # generate batches on demand; prefetch_batches consumes this generator
    # in a background thread so the next batch is ready before it is needed
    def batch_generator():
        for _ in range(model.max_steps):
            key = random_key_generator()
            yield get_batch(X, y, key, batch_size=model.batch_size)

    batches = prefetch_batches(batch_generator())
    for step, (X_batch, y_batch) in enumerate(batches):
        params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch)
@@ -93,22 +165,28 @@ def update(params, opt_state, x, y):
        if np.isnan(loss_val):
            logging.info("nan encountered. Training aborted.")
            break

        # the convergence criterion is checked at each step
        if criterion.check_convergence(loss_val):
            logging.info(f"Model {model.__class__.__name__} converged after {step} steps.")
            converged = True
            break

        # previous interval-based convergence check, superseded by the
        # ConvergenceCriterion above:
        # if step > 2 * convergence_interval:
        #     # get means of the last two intervals and the standard deviation of the last interval
        #     average1 = np.mean(loss_history[-convergence_interval:])
        #     average2 = np.mean(
        #         loss_history[-2 * convergence_interval : -convergence_interval]
        #     )
        #     std1 = np.std(loss_history[-convergence_interval:])
        #     # if the difference in averages is small compared to the statistical fluctuations, stop training
        #     if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2:
        #         logging.info(
        #             f"Model {model.__class__.__name__} converged after {step} steps."
        #         )
        #         converged = True
        #         break

    end = time.time()
    loss_history = np.array(loss_history)
@@ -144,6 +222,28 @@ def get_batch(X, y, rnd_key, batch_size=32):
    )
    return X[rnd_indices], y[rnd_indices]

def prefetch_batches(batch_generator, queue_size=10):
    """
    A prefetch mechanism that loads upcoming batches in a background thread
    while the current batch is being processed.

    A thread is used rather than a process: the producer closes over a local
    generator, which cannot be pickled for a spawned process, and a daemon
    thread is cleaned up automatically if the consumer stops iterating early.

    Args:
        batch_generator (generator): A generator that yields batches of data.
        queue_size (int): The maximum number of batches to prefetch.

    Returns:
        A generator that yields batches of data.
    """
    buffer = Queue(maxsize=queue_size)

    def loader():
        for batch in batch_generator:
            buffer.put(batch)
        buffer.put(None)  # sentinel: signals that the generator is exhausted

    threading.Thread(target=loader, daemon=True).start()
    while True:
        batch = buffer.get()
        if batch is None:
            break
        yield batch
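
# Illustrative usage (hypothetical names): wrap any batch generator to overlap
# batch preparation with the optimization step:
#
#     def my_batches():
#         for _ in range(1000):
#             key = random_key_generator()
#             yield get_batch(X, y, key, batch_size=32)
#
#     for X_batch, y_batch in prefetch_batches(my_batches(), queue_size=4):
#         ...  # train on this batch while later ones are prefetched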

def get_from_dict(dict, key_list):
    """