A PyTorch learning rate scheduler that uses second-order (Hessian) information to adaptively set learning rates based on the local curvature of the loss landscape.
- 🧮 Mathematically Principled: Based on second-order Taylor expansion and Newton's method
- ⚡ Efficient: Only ~0.25% runtime overhead through smart Hessian-vector products
- 🛡️ Robust: Handles negative curvature, numerical instabilities, and edge cases
- 🔌 Easy Integration: Drop-in replacement for standard PyTorch schedulers
- 🎯 Framework Support: Native integrations for Transformers, Lightning, and W&B
```bash
pip install hessian-lr
```
```python
import torch
from hessian_lr import HessianLR

# Initialize model and optimizer
model = YourModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Create the HessianLR scheduler
scheduler = HessianLR(
    optimizer,
    num_warmup_steps=100,  # Linear warmup
    update_period=10,      # Recalculate every 10 steps
    lr_bounds=(0.5, 2.0),  # Limit LR changes to 2x
    smoothing_factor=0.9   # EMA smoothing
)

# Training loop
for batch in dataloader:
    # Forward pass
    loss = model(batch)

    # CRITICAL: use create_graph=True for second-order derivatives
    loss.backward(create_graph=True)

    # Step the scheduler BEFORE the optimizer
    scheduler.step(model=model, loss=loss)
    optimizer.step()
    optimizer.zero_grad()

    # Clean up the computation graph to save memory
    del loss
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```
HessianLR estimates the optimal learning rate using a quadratic approximation of the loss function:
```
lr* = (g^T g) / (g^T H g)
```

where:

- `g` is the gradient vector
- `H` is the Hessian matrix
- `g^T H g` is computed efficiently using Hessian-vector products
This gives us the optimal step size for a quadratic approximation of the loss, automatically adapting to the local geometry:
- Flat regions (low curvature) → larger learning rates
- Sharp valleys (high curvature) → smaller learning rates
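As a rough sketch of how this quantity can be obtained with two backward passes and no explicit Hessian, the helper below estimates `lr*` for a list of parameters. It is illustrative only, not the library's internals; `estimate_optimal_lr` and the `eps` cutoff are assumptions for the example.

```python
import torch

def estimate_optimal_lr(loss, params, eps=1e-12):
    """Sketch: lr* = (g^T g) / (g^T H g) via a Hessian-vector product."""
    # First backward pass: keep the graph so the gradients can be differentiated again
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Second backward pass: Hessian-vector product H·g (the Hessian is never materialized)
    hvp = torch.autograd.grad(grads, params, grad_outputs=grads)

    g_dot_g = sum((g * g).sum() for g in grads)
    g_dot_hg = sum((g * h).sum() for g, h in zip(grads, hvp))

    # Negative or near-zero curvature makes the quadratic model unreliable;
    # a scheduler would fall back (HessianLR decays the LR in that case)
    if g_dot_hg <= eps:
        return None
    return (g_dot_g / g_dot_hg).item()
```

In practice the scheduler clamps and smooths this raw estimate (see the configuration table below) rather than applying it directly.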
- Runtime: ~0.25% overhead with `update_period=10`
- Memory: ~50% increase during HVP computation
- Scalability: Tested on models up to 1B parameters
- Faster convergence in early training
- Better final performance in many cases
- More stable training dynamics
- Automatic adaptation to loss landscape changes
| Parameter | Default | Description |
|---|---|---|
| `num_warmup_steps` | `0` | Linear warmup period before adaptive scheduling |
| `update_period` | `10` | Steps between LR recalculations (higher = more efficient) |
| `lr_bounds` | `(0.3, 3.0)` | Multiplicative bounds on LR changes per update |
| `lr_min` | `1e-8` | Absolute minimum learning rate |
| `lr_max` | `1.0` | Absolute maximum learning rate |
| `smoothing_factor` | `0.9` | EMA factor for LR smoothing (0-1, higher = smoother) |
| `negative_curvature_decay` | `0.5` | LR decay factor when curvature is negative/unreliable |
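To make the interplay between these parameters concrete, here is a purely illustrative sketch of how a raw curvature-based estimate could be turned into the next learning rate. The `apply_lr_update` helper is hypothetical, not the library's actual code path.

```python
def apply_lr_update(current_lr, lr_estimate, curvature,
                    lr_bounds=(0.3, 3.0), lr_min=1e-8, lr_max=1.0,
                    smoothing_factor=0.9, negative_curvature_decay=0.5):
    """Hypothetical illustration of how the parameters above interact."""
    if curvature <= 0:
        # Negative/unreliable curvature: back off rather than trust the estimate
        target = current_lr * negative_curvature_decay
    else:
        # Limit the multiplicative change per update ...
        low, high = lr_bounds
        target = min(max(lr_estimate, current_lr * low), current_lr * high)
    # ... and keep the result inside the absolute [lr_min, lr_max] range
    target = min(max(target, lr_min), lr_max)
    # EMA smoothing: a higher smoothing_factor means slower LR changes
    return smoothing_factor * current_lr + (1 - smoothing_factor) * target
```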
For stable training:

```python
scheduler = HessianLR(
    optimizer,
    update_period=20,        # Less frequent updates
    lr_bounds=(0.5, 1.5),    # Conservative bounds
    smoothing_factor=0.95    # Heavy smoothing
)
```
For aggressive adaptation:

```python
scheduler = HessianLR(
    optimizer,
    update_period=5,         # Frequent updates
    lr_bounds=(0.2, 5.0),    # Wide bounds
    smoothing_factor=0.7     # Less smoothing
)
```
```python
from hessian_lr.integrations.transformers import get_hessian_scheduler

# In your training script
scheduler = get_hessian_scheduler(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=10000  # Optional, for compatibility
)

# Or use with Trainer
from transformers import Trainer

class HessianTrainer(Trainer):
    def create_scheduler(self, num_training_steps, optimizer=None):
        return get_hessian_scheduler(
            self.optimizer if optimizer is None else optimizer,
            num_warmup_steps=self.args.warmup_steps
        )
```
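For completeness, a minimal hypothetical usage of the subclass above, assuming `model`, `training_args`, and `train_dataset` come from your own standard Transformers setup:

```python
# model, training_args, and train_dataset are placeholders from your own setup
trainer = HessianTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()  # Trainer calls create_scheduler() internally during setup
```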
```python
import pytorch_lightning as pl
from hessian_lr.integrations.lightning import HessianLRCallback

trainer = pl.Trainer(
    callbacks=[
        HessianLRCallback(
            scheduler_kwargs={
                "update_period": 10,
                "smoothing_factor": 0.9
            }
        )
    ]
)
```
```python
from hessian_lr.integrations.wandb import WandbHessianLRLogger

# Initialize logger
wandb_logger = WandbHessianLRLogger(
    prefix="hessian_lr",
    log_frequency=10
)

# In training loop
wandb_logger.log_scheduler_step(scheduler, step=global_step)
```
The `create_graph=True` requirement increases memory usage. Here are strategies to manage it:
Reduce the update frequency:

```python
# Only compute Hessian information every 50 steps
scheduler = HessianLR(optimizer, update_period=50)
```
Combine with gradient accumulation, so only the micro-batch that triggers an update needs the second-order graph:

```python
accumulation_steps = 4  # micro-batches per optimizer update

for i, batch in enumerate(dataloader):
    is_update_step = (i + 1) % accumulation_steps == 0
    loss = model(batch) / accumulation_steps
    # Only build the second-order graph on the micro-batch that triggers the update
    loss.backward(create_graph=is_update_step)
    if is_update_step:
        scheduler.step(model=model, loss=loss)
        optimizer.step()
        optimizer.zero_grad()
```
```python
from hessian_lr.utils import estimate_memory_overhead

# Check memory requirements
memory_info = estimate_memory_overhead(model)
print(f"Estimated overhead: {memory_info['total_estimated_mb']:.1f} MB")
```
```bash
# Basic tests
pytest tests/

# With coverage
pytest tests/ --cov=hessian_lr

# Specific test
pytest tests/test_scheduler.py -k "test_warmup"
```
```python
from hessian_lr.utils import validate_scheduler_setup

# Check for common issues
warnings = validate_scheduler_setup(optimizer, model, loss)
for warning in warnings:
    print(f"Warning: {warning}")
```
Complete examples in the `examples/` directory:

- `basic_usage.py` - Simple neural network training
- `transformers_example.py` - Fine-tuning BERT with HessianLR
- `lightning_example.py` - PyTorch Lightning integration
- `benchmarks.py` - Performance comparisons
- Requires `create_graph=True`: This roughly doubles backward pass memory
- Call order matters: Always call `scheduler.step()` before `optimizer.step()`
- Not suitable for all models: Very large models may face memory constraints
- Best for smooth losses: Works best when loss landscape is reasonably smooth
We welcome contributions! Please see our Contributing Guidelines.
```bash
git clone https://github.com/QuixiAI/hessian-lr
cd hessian-lr
pip install -e ".[dev]"
pre-commit install
```
If you use HessianLR in your research, please cite:
```bibtex
@software{hessian_lr,
  title={HessianLR: Adaptive Learning Rate Scheduling via Hessian Estimation},
  author={Eric Hartford},
  year={2024},
  url={https://github.com/QuixiAI/hessian-lr}
}
```
- Inspired by second-order optimization methods and AdaHessian
- Built on PyTorch's excellent automatic differentiation
- Thanks to the open-source community for feedback and contributions
MIT License - see LICENSE file for details.
Questions? Open an issue or reach out on Twitter
Found it useful? Give us a ⭐ on GitHub!