HessianLR - Adaptive Learning Rate Scheduling via Hessian Estimation

Python 3.8+ · PyTorch 1.9+ · License: MIT

A PyTorch learning rate scheduler that uses second-order (Hessian) information to adaptively set learning rates based on the local curvature of the loss landscape.

✨ Key Features

  • 🧮 Mathematically Principled: Based on second-order Taylor expansion and Newton's method
  • ⚡ Efficient: Only ~0.25% runtime overhead thanks to periodic Hessian-vector products
  • 🛡️ Robust: Handles negative curvature, numerical instabilities, and edge cases
  • 🔌 Easy Integration: Drop-in replacement for standard PyTorch schedulers
  • 🎯 Framework Support: Native integrations for Transformers, Lightning, and W&B

🚀 Quick Start

Installation

pip install hessian-lr

Basic Usage

import torch
from hessian_lr import HessianLR

# Initialize model and optimizer
model = YourModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Create the HessianLR scheduler
scheduler = HessianLR(
    optimizer,
    num_warmup_steps=100,    # Linear warmup
    update_period=10,        # Recalculate every 10 steps
    lr_bounds=(0.5, 2.0),    # Limit LR changes to 2x
    smoothing_factor=0.9     # EMA smoothing
)

# Training loop
for batch in dataloader:
    # Forward pass
    loss = model(batch)
    
    # CRITICAL: Use create_graph=True for second-order derivatives
    loss.backward(create_graph=True)
    
    # Step scheduler BEFORE optimizer
    scheduler.step(model=model, loss=loss)
    optimizer.step()
    optimizer.zero_grad()
    
    # Clean up computation graph to save memory
    del loss
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

🧠 How It Works

HessianLR estimates the optimal learning rate using a quadratic approximation of the loss function:

lr* = (g^T g) / (g^T H g)

Where:

  • g is the gradient vector
  • H is the Hessian matrix
  • g^T H g is computed efficiently using Hessian-vector products

This is the exact minimizer of the quadratic model along the gradient direction, so the step size automatically adapts to the local geometry (a code sketch follows):

  • Flat regions (low curvature) → larger learning rates
  • Sharp valleys (high curvature) → smaller learning rates
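
A minimal sketch of this estimate using PyTorch's double-backward facility. It mirrors the idea only, not the library's internal implementation; model and loss are assumed to come from a normal forward pass:

import torch

def estimate_optimal_lr(loss, params, eps=1e-12):
    # First-order gradients, keeping the graph so we can differentiate again
    grads = torch.autograd.grad(loss, params, create_graph=True)
    g_dot_g = sum((g * g).sum() for g in grads)

    # Hessian-vector product H g: backprop through the gradients, contracted with g
    Hg = torch.autograd.grad(grads, params, grad_outputs=grads)
    g_H_g = sum((g * hg).sum() for g, hg in zip(grads, Hg))

    if g_H_g <= eps:  # negative or near-zero curvature: no reliable quadratic step
        return None
    return (g_dot_g / g_H_g).item()

# Illustrative usage:
# params = [p for p in model.parameters() if p.requires_grad]
# lr_star = estimate_optimal_lr(loss, params)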

📊 Performance Characteristics

Computational Overhead

  • Runtime: ~0.25% with update_period=10
  • Memory: ~50% increase during HVP computation
  • Scalability: Tested on models up to 1B parameters

Convergence Benefits

  • Faster convergence in early training
  • Better final performance in many cases
  • More stable training dynamics
  • Automatic adaptation to loss landscape changes

⚙️ Configuration

Core Parameters

  • num_warmup_steps (default: 0): Linear warmup period before adaptive scheduling
  • update_period (default: 10): Steps between LR recalculations (higher = more efficient)
  • lr_bounds (default: (0.3, 3.0)): Multiplicative bounds on LR changes per update
  • lr_min (default: 1e-8): Absolute minimum learning rate
  • lr_max (default: 1.0): Absolute maximum learning rate
  • smoothing_factor (default: 0.9): EMA factor for LR smoothing (0-1, higher = smoother)
  • negative_curvature_decay (default: 0.5): LR decay factor when curvature is negative or unreliable
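
A hypothetical sketch of how these parameters could combine on a single update step; the actual order of operations inside HessianLR may differ:

def next_lr(current_lr, lr_star, lr_bounds=(0.3, 3.0), lr_min=1e-8, lr_max=1.0,
            smoothing_factor=0.9, negative_curvature_decay=0.5):
    if lr_star is None:
        # Negative or unreliable curvature: decay instead of trusting the estimate
        proposed = current_lr * negative_curvature_decay
    else:
        # EMA smoothing toward the curvature-based estimate
        proposed = smoothing_factor * current_lr + (1 - smoothing_factor) * lr_star
    # Multiplicative bounds per update, then absolute bounds
    proposed = min(max(proposed, current_lr * lr_bounds[0]), current_lr * lr_bounds[1])
    return min(max(proposed, lr_min), lr_max)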

Choosing Parameters

For stable training:

scheduler = HessianLR(
    optimizer,
    update_period=20,        # Less frequent updates
    lr_bounds=(0.5, 1.5),    # Conservative bounds
    smoothing_factor=0.95    # Heavy smoothing
)

For aggressive adaptation:

scheduler = HessianLR(
    optimizer,
    update_period=5,         # Frequent updates
    lr_bounds=(0.2, 5.0),    # Wide bounds
    smoothing_factor=0.7     # Less smoothing
)

🔌 Framework Integrations

Hugging Face Transformers

from hessian_lr.integrations.transformers import get_hessian_scheduler

# In your training script
scheduler = get_hessian_scheduler(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=10000  # Optional, for compatibility
)

# Or use with Trainer
from transformers import Trainer

class HessianTrainer(Trainer):
    def create_scheduler(self, num_training_steps, optimizer=None):
        return get_hessian_scheduler(
            self.optimizer if optimizer is None else optimizer,
            num_warmup_steps=self.args.warmup_steps
        )
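
A minimal sketch of using the subclass in a standard fine-tuning setup; model and train_dataset are placeholders, and the TrainingArguments values are illustrative only:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    warmup_steps=500,
    per_device_train_batch_size=8,
    num_train_epochs=3,
)

trainer = HessianTrainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()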

PyTorch Lightning

import pytorch_lightning as pl

from hessian_lr.integrations.lightning import HessianLRCallback

trainer = pl.Trainer(
    callbacks=[
        HessianLRCallback(
            scheduler_kwargs={
                'update_period': 10,
                'smoothing_factor': 0.9
            }
        )
    ]
)

Weights & Biases

from hessian_lr.integrations.wandb import WandbHessianLRLogger

# Initialize logger
wandb_logger = WandbHessianLRLogger(
    prefix="hessian_lr",
    log_frequency=10
)

# In training loop
wandb_logger.log_scheduler_step(scheduler, step=global_step)

📈 Memory Management

The create_graph=True requirement increases memory usage. Here are strategies to manage it:

1. Increase Update Period

# Only compute Hessian every 50 steps
scheduler = HessianLR(optimizer, update_period=50)

2. Gradient Accumulation Compatible

for i, batch in enumerate(dataloader):
    is_update_step = (i + 1) % accumulation_steps == 0
    loss = model(batch) / accumulation_steps
    # Keep the graph only on the step where the scheduler runs
    loss.backward(create_graph=is_update_step)

    if is_update_step:
        scheduler.step(model=model, loss=loss)
        optimizer.step()
        optimizer.zero_grad()

3. Memory Monitoring

from hessian_lr.utils import estimate_memory_overhead

# Check memory requirements
memory_info = estimate_memory_overhead(model)
print(f"Estimated overhead: {memory_info['total_estimated_mb']:.1f} MB")

🧪 Testing & Validation

Run Tests

# Basic tests
pytest tests/

# With coverage
pytest tests/ --cov=hessian_lr

# Specific test
pytest tests/test_scheduler.py -k "test_warmup"

Validate Your Setup

from hessian_lr.utils import validate_scheduler_setup

# Check for common issues
warnings = validate_scheduler_setup(optimizer, model, loss)
for warning in warnings:
    print(f"Warning: {warning}")

📚 Examples

Complete examples in the examples/ directory:

  • basic_usage.py - Simple neural network training
  • transformers_example.py - Fine-tuning BERT with HessianLR
  • lightning_example.py - PyTorch Lightning integration
  • benchmarks.py - Performance comparisons

⚠️ Important Considerations

  1. Requires create_graph=True: This roughly doubles backward pass memory
  2. Call order matters: Always call scheduler.step() before optimizer.step()
  3. Not suitable for all models: Very large models may face memory constraints
  4. Best for smooth losses: Works best when loss landscape is reasonably smooth

🤝 Contributing

We welcome contributions! Please see our Contributing Guidelines.

Development Setup

git clone https://github.com/QuixiAI/hessian-lr
cd hessian-lr
pip install -e ".[dev]"
pre-commit install

📖 Citation

If you use HessianLR in your research, please cite:

@software{hessian_lr,
  title={HessianLR: Adaptive Learning Rate Scheduling via Hessian Estimation},
  author={Eric Hartford},
  year={2024},
  url={https://github.com/QuixiAI/hessian-lr}
}

🙏 Acknowledgments

  • Inspired by second-order optimization methods and AdaHessian
  • Built on PyTorch's excellent automatic differentiation
  • Thanks to the open-source community for feedback and contributions

📄 License

MIT License - see LICENSE file for details.


Questions? Open an issue or reach out on Twitter

Found it useful? Give us a ⭐ on GitHub!
