diff --git a/.gitignore b/.gitignore index 5e48521b8..379db96d5 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,8 @@ autoemulate/experimental/exploratory/data case_studies/**/data/ # Benchmarking data -benchmarks/data/ \ No newline at end of file +benchmarks/data/ + +# Ignore experimental outputs +autoemulate/experimental/outputs/ +outputs \ No newline at end of file diff --git a/autoemulate/emulators/base.py b/autoemulate/emulators/base.py index c352f4eef..ecab6d7f7 100644 --- a/autoemulate/emulators/base.py +++ b/autoemulate/emulators/base.py @@ -560,8 +560,8 @@ class PyTorchBackend(nn.Module, Emulator): epochs: int = 10 loss_history: ClassVar[list[float]] = [] verbose: bool = False - loss_fn: nn.Module = nn.MSELoss() - optimizer_cls: type[optim.Optimizer] = optim.Adam + loss_fn: nn.Module + optimizer_cls: type[optim.Optimizer] optimizer: optim.Optimizer supports_grad: bool = True lr: float = 1e-1 @@ -608,9 +608,10 @@ def _fit(self, x: TensorLike, y: TensorLike): # Track loss epoch_loss += loss.item() batches += 1 + # Update learning rate if scheduler is defined if self.scheduler is not None: - self.scheduler.step() # type: ignore[call-arg] + self.scheduler.step() # Average loss for the epoch avg_epoch_loss = epoch_loss / batches diff --git a/autoemulate/experimental/README.md b/autoemulate/experimental/README.md new file mode 100644 index 000000000..f9a9df15c --- /dev/null +++ b/autoemulate/experimental/README.md @@ -0,0 +1,478 @@ +# Spatiotemporal Emulation Framework + +A unified YAML-based framework for training neural emulators on spatiotemporal data using The Well. + +## Quick Start + +```bash +# Run with example configs +python autoemulate/experimental/run_the_well_experiment.py \ + --config autoemulate/experimental/configs/turbulent_radiative_layer_2d.yaml + +# Create a template config +python autoemulate/experimental/run_the_well_experiment.py --create-example +``` + +## Data Sources (Auto-detected) + +The script automatically detects which data source to use: + +### 1. The Well Native Datasets +```yaml +data: + well_dataset_name: "turbulent_radiative_layer_2D" # Triggers The Well native loading + data_path: "../data/the_well/datasets" + n_steps_input: 4 + n_steps_output: 1 + batch_size: 4 +``` + +**Example:** `configs/turbulent_radiative_layer_2d.yaml` + +### 2. File-Based Data +```yaml +data: + data_path: "./data/bout" # Path to HDF5/PyTorch files + dataset_type: "bout" + n_steps_input: 4 + n_steps_output: 10 + batch_size: 4 +``` + +**Example:** `configs/bout.yaml` + +### 3. 
Generated Data (from Simulators) +```yaml +simulator: + type: "reaction_diffusion" # or advection_diffusion + parameters_range: + feed_rate: [0.02, 0.06] + kill_rate: [0.045, 0.065] + n: 64 + T: 100.0 + dt: 1.0 + +data: + n_train_samples: 200 + n_valid_samples: 20 + n_test_samples: 20 + n_steps_input: 4 + n_steps_output: 10 + batch_size: 4 +``` + +**Examples:** `configs/reaction_diffusion_generated.yaml`, `configs/advection_diffusion_generated.yaml` + +## Example Configurations + +All configs are in `autoemulate/experimental/configs/`: + +| Config | Data Source | Description | +|--------|-------------|-------------| +| `turbulent_radiative_layer_2d.yaml` | The Well Native | Turbulent flows with radiation | +| `bout.yaml` | File | BOUT++ plasma simulation data | +| `advection_diffusion_generated.yaml` | Generated | Transport phenomena (nu, mu parameters) | +| `reaction_diffusion_generated.yaml` | Generated | Gray-Scott pattern formation (F, k parameters) | + +## Configuration Structure + +```yaml +# Basic info +experiment_name: "my_experiment" +description: "What this experiment does" + +# Emulator +emulator_type: "the_well_fno" # FNO, AFNO, UNet variants +formatter_type: "default_channels_first" + +# Model architecture +model_params: + modes1: 16 # Fourier modes (x-direction) + modes2: 16 # Fourier modes (y-direction) + width: 32 # Hidden channels + n_blocks: 4 # Number of FNO blocks + +# Training +trainer: + epochs: 100 + device: "mps" # "cuda", "mps", or "cpu" + optimizer_type: "adam" + optimizer_params: + lr: 0.001 + enable_amp: false # Mixed precision (CUDA only) + +# Teacher forcing (optional) +teacher_forcing: + enabled: true + schedule: + - start_epoch: 0 + end_epoch: 50 + weight: 1.0 + - start_epoch: 50 + end_epoch: 100 + weight: 0.5 + +# Paths +paths: + output_dir: "./outputs/my_experiment" + model_save_path: "./outputs/my_experiment/artifacts/final_model.pt" + +# Logging +log_level: "INFO" +verbose: true +``` + +## Emulator Types + +- `the_well_fno` - Fourier Neural Operator (recommended) +- `the_well_afno` - Adaptive FNO +- `the_well_unet_classic` - Classic U-Net +- `the_well_unet_convnext` - ConvNext U-Net + +## Device Support + +### CPU +```yaml +trainer: + device: "cpu" +``` + +### Apple Silicon (MPS) +```yaml +trainer: + device: "mps" +``` + +### NVIDIA GPU (CUDA) +```yaml +trainer: + device: "cuda" + enable_amp: true # Enable mixed precision + amp_type: "float16" +``` + +## Output Structure + +``` +outputs/experiment_name/ +├── config.yaml # Configuration used +├── logs/ +│ └── experiment_*.log # Timestamped logs +├── checkpoints/ +│ └── checkpoint_*.pt # Training checkpoints +└── artifacts/ + └── final_model.pt # Final trained model +``` + +## Common Tasks + +### Quick Test Run +Set low epochs for testing: +```yaml +trainer: + epochs: 10 +``` + +### Adjust Memory Usage +Change batch size: +```yaml +data: + batch_size: 8 # Increase if you have memory +``` + +### Change Dataset Size (Generated Data) +```yaml +data: + n_train_samples: 500 # More data = better emulator + n_valid_samples: 50 + n_test_samples: 50 +``` + +## Adding New Datasets + +> **Note:** Adding datasets requires editing source Python files (`config_models.py`, `run_the_well_experiment.py`, and your dataset class). This ensures type safety and explicit registration. + +### For File-Based Data + +If you have pre-existing HDF5 or PyTorch tensor files: + +**1. 
Prepare your data in the expected format:** +- Shape: `[n_trajectories, n_timesteps, width, height, n_channels]` +- Save as HDF5 with key `"data"` or PyTorch `.pt` file + +**2. Create a dataset class** (optional, if custom loading needed): + +```python +# In autoemulate/experimental/data/spatiotemporal_dataset.py +class MyCustomDataset(AutoEmulateDataset): + def read_data(self, data_path: str): + """Load your custom data format.""" + # Load from your custom format + data = load_my_custom_format(data_path) + # Convert to expected shape + self.data = torch.tensor(data, dtype=self.dtype) +``` + +**3. Register dataset type:** + +Add to `DatasetType` enum in `config_models.py`: +```python +class DatasetType(str, Enum): + MY_DATASET = "my_dataset" +``` + +**4. Register in script:** + +Add to `get_dataset_class()` in `run_the_well_experiment.py`: +```python +def get_dataset_class(dataset_type: DatasetType): + dataset_classes = { + DatasetType.MY_DATASET: MyCustomDataset, + # ... existing datasets + } +``` + +Also import at the top of `run_the_well_experiment.py`: +```python +from autoemulate.experimental.data.spatiotemporal_dataset import ( + MyCustomDataset, + # ... other imports +) +``` + +**5. Create config YAML:** +```yaml +data: + data_path: "./data/my_dataset" + dataset_type: "my_dataset" + n_steps_input: 4 + n_steps_output: 10 + batch_size: 4 +``` + +**Summary of files to edit:** `config_models.py` (enum), `run_the_well_experiment.py` (register + import), `spatiotemporal_dataset.py` (class) + +### For Generated Data (Simulators) + +> **Note:** Adding simulators requires editing source Python files (`config_models.py`, `run_the_well_experiment.py`, and creating your simulator class). + +**1. Create simulator class:** + +```python +# In autoemulate/simulations/my_simulator.py +from autoemulate.simulations.base import Simulator + +class MySimulator(Simulator): + def __init__(self, parameters_range, output_names, + return_timeseries=False, n=64, T=10.0, dt=0.1): + super().__init__(parameters_range, output_names) + self.return_timeseries = return_timeseries + self.n = n + self.T = T + self.dt = dt + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """Run simulation for parameters in x.""" + # Your simulation code here + result = run_my_simulation(x, self.n, self.T, self.dt) + return torch.tensor(result).reshape(1, -1) + + def forward_samples_spatiotemporal(self, n: int, + random_seed: int | None = None): + """Generate n samples and reshape to spatiotemporal format.""" + # Generate samples + X, y = self.forward_samples(n, random_seed) + + # Reshape to [n_samples, n_timesteps, width, height, n_channels] + # ... your reshaping code + + return { + "data": reshaped_data, + "constant_scalars": None, + "constant_fields": None, + } +``` + +**2. Add to `SimulatorType` enum:** + +In `config_models.py`: +```python +class SimulatorType(str, Enum): + MY_SIMULATOR = "my_simulator" +``` + +**3. Register in `create_simulator()`:** + +In `run_the_well_experiment.py`: +```python +def create_simulator(config: ExperimentConfig): + sim_cfg = config.simulator + + if sim_cfg.type.value == "my_simulator": + return MySimulator( + parameters_range=sim_cfg.parameters_range, + output_names=sim_cfg.output_names, + return_timeseries=sim_cfg.return_timeseries, + n=sim_cfg.n, + T=sim_cfg.T, + dt=sim_cfg.dt, + ) +``` + +Also import at the top of `run_the_well_experiment.py`: +```python +from autoemulate.simulations.my_simulator import MySimulator +``` + +**4. 
Create config YAML:** +```yaml +simulator: + type: "my_simulator" + parameters_range: + param1: [min, max] + param2: [min, max] + n: 64 + T: 10.0 + dt: 0.1 + +data: + n_train_samples: 200 + n_valid_samples: 20 + n_test_samples: 20 + n_steps_input: 4 + n_steps_output: 10 + batch_size: 4 +``` + +**Summary of files to edit:** `config_models.py` (enum), `run_the_well_experiment.py` (register + import), `my_simulator.py` (class) + +## Adding New Emulators + +> **Note:** Adding emulators requires editing source Python files (`config_models.py`, `run_the_well_experiment.py`, and your emulator class). This ensures type safety and validation. + +**1. Create emulator class:** + +```python +# In autoemulate/experimental/emulators/the_well.py +from the_well.benchmark import models +from typing import ClassVar + +class TheWellMyModel(TheWellEmulator): + """My custom emulator using The Well framework.""" + + # Specify the model class from The Well or your own + model_cls: type[torch.nn.Module] = models.FNO # or your custom model + + # Define default model parameters + model_parameters: ClassVar[ModelParams] = { + "modes1": 16, + "modes2": 16, + "width": 32, + # ... other model-specific params + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) +``` + +**For custom PyTorch models not in The Well:** + +```python +class MyCustomModel(torch.nn.Module): + """Your custom neural network architecture.""" + + def __init__(self, in_channels, out_channels, width=32, **kwargs): + super().__init__() + # Define your architecture + self.conv1 = torch.nn.Conv2d(in_channels, width, 3, padding=1) + # ... more layers + + def forward(self, x): + # Your forward pass + return output + +class TheWellMyCustomModel(TheWellEmulator): + model_cls: type[torch.nn.Module] = MyCustomModel + model_parameters: ClassVar[ModelParams] = { + "width": 32, + } +``` + +**2. Add to `EmulatorType` enum:** + +In `config_models.py`: +```python +class EmulatorType(str, Enum): + THE_WELL_MY_MODEL = "the_well_my_model" +``` + +**3. Register in `create_emulator()`:** + +In `run_the_well_experiment.py`: +```python +def create_emulator(config: ExperimentConfig, datamodule): + emulator_classes = { + EmulatorType.THE_WELL_MY_MODEL: TheWellMyModel, + # ... existing emulators + } +``` + +**4. Import the emulator:** + +In `run_the_well_experiment.py` at the top: +```python +from autoemulate.experimental.emulators.the_well import ( + TheWellMyModel, + # ... other imports +) +``` + +**5. Create config YAML:** +```yaml +emulator_type: "the_well_my_model" +formatter_type: "default_channels_first" + +model_params: + modes1: 16 + modes2: 16 + width: 32 + # ... 
other model-specific parameters + +trainer: + epochs: 100 + device: "cuda" + optimizer_type: "adam" + optimizer_params: + lr: 0.001 +``` + +**Summary of files to edit:** `config_models.py` (enum), `run_the_well_experiment.py` (register + import), `the_well.py` (class) + +**Tips:** +- Inherit from `TheWellEmulator` to get training loop, checkpointing, and logging for free +- Set `model_parameters` as class variable for defaults that can be overridden in config +- Use existing formatters (`default_channels_first`, `default_channels_first_with_time`) or create custom ones +- Available base models in The Well: `FNO`, `AFNO`, `UNet`, `ConvNextUNet` + +## Troubleshooting + +**CUDA out of memory:** Reduce `data.batch_size` or `model_params.width` + +**Slow training:** Enable AMP on CUDA: `trainer.enable_amp: true` + +**Data not found:** Check paths in config match your directory structure + +## Design Philosophy + +**Why explicit registration?** +The framework uses explicit registration (editing Python files) rather than auto-discovery to ensure: +- ✅ Type safety through Pydantic validation +- ✅ Clear mapping of what's available +- ✅ IDE support for autocomplete and type checking +- ✅ No "magic" imports or hidden behavior + +**Future improvement:** A decorator-based registry system could allow adding custom models without editing core files while maintaining most benefits. + +## Further Information + +See `configs/README.md` for detailed config documentation. diff --git a/autoemulate/experimental/config_models.py b/autoemulate/experimental/config_models.py new file mode 100644 index 000000000..b47837389 --- /dev/null +++ b/autoemulate/experimental/config_models.py @@ -0,0 +1,401 @@ +"""Pydantic models for configuring spatiotemporal emulation experiments.""" + +from enum import Enum +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel, Field, field_validator + + +class SimulatorType(str, Enum): + """Available simulator types.""" + + ADVECTION_DIFFUSION = "advection_diffusion" + REACTION_DIFFUSION = "reaction_diffusion" + + +class DataSourceType(str, Enum): + """Data source type - determines how data is loaded.""" + + GENERATED = "generated" # Generate data from simulator + WELL_NATIVE = "well_native" # Use The Well's native datasets + FILE = "file" # Load from existing files (HDF5/PyTorch) + + +class DatasetType(str, Enum): + """Available dataset types.""" + + ADVECTION_DIFFUSION = "advection_diffusion" + REACTION_DIFFUSION = "reaction_diffusion" + MHD = "mhd" + BOUT = "bout" + GENERIC = "generic" + + +class EmulatorType(str, Enum): + """Available emulator types.""" + + THE_WELL_FNO = "the_well_fno" + THE_WELL_FNO_WITH_TIME = "the_well_fno_with_time" + THE_WELL_FNO_WITH_LEARNABLE_WEIGHTS = "the_well_fno_with_learnable_weights" + THE_WELL_AFNO = "the_well_afno" + THE_WELL_UNET_CLASSIC = "the_well_unet_classic" + THE_WELL_UNET_CONVNEXT = "the_well_unet_convnext" + + +class FormatterType(str, Enum): + """Available data formatter types.""" + + DEFAULT_CHANNELS_FIRST = "default_channels_first" + DEFAULT_CHANNELS_FIRST_WITH_TIME = "default_channels_first_with_time" + + +class OptimizerType(str, Enum): + """Available optimizer types.""" + + ADAM = "adam" + ADAMW = "adamw" + SGD = "sgd" + RMSPROP = "rmsprop" + + +class LRSchedulerType(str, Enum): + """Available learning rate scheduler types.""" + + STEP_LR = "step_lr" + EXPONENTIAL_LR = "exponential_lr" + COSINE_ANNEALING_LR = "cosine_annealing_lr" + REDUCE_LR_ON_PLATEAU = "reduce_lr_on_plateau" + + +class 
LossFunctionType(str, Enum): + """Available loss function types.""" + + VRMSE = "vrmse" + MSE = "mse" + MAE = "mae" + + +class SimulatorConfig(BaseModel): + """Configuration for simulator.""" + + type: SimulatorType = Field( + default=SimulatorType.ADVECTION_DIFFUSION, + description="Type of simulator to use", + ) + parameters_range: dict[str, tuple[float, float]] = Field( + default={"nu": (0.0001, 0.01), "mu": (0.5, 2.0)}, + description="Parameter ranges for the simulator", + ) + output_names: list[str] = Field( + default=["solution"], description="Names of output variables" + ) + return_timeseries: bool = Field( + default=True, description="Whether to return full timeseries" + ) + n: int = Field(default=64, description="Number of spatial points per direction") + L: float = Field(default=10.0, description="Domain size in X and Y directions") + T: float = Field(default=10.0, description="Total simulation time") + dt: float = Field(default=0.1, description="Time step size") + + +class DataConfig(BaseModel): + """Configuration for data generation and loading.""" + + # Data source - automatically determined from config fields + source_type: DataSourceType | None = Field( + default=None, + description="Data source type (auto-detected if None)", + ) + + # Data paths + data_path: str | None = Field( + default=None, + description="Path to data directory (for file loading) or Well datasets base path (for well_native)", # noqa: E501 + ) + + # The Well native dataset configuration + well_dataset_name: str | None = Field( + default=None, + description="Name of The Well dataset (e.g., 'turbulent_radiative_layer_2D'). Set this to use Well native datasets.", # noqa: E501 + ) + + # Data generation (for generated source type) + n_train_samples: int = Field( + default=200, description="Number of training samples to generate" + ) + n_valid_samples: int = Field( + default=4, description="Number of validation samples to generate" + ) + n_test_samples: int = Field( + default=4, description="Number of test samples to generate" + ) + random_seed: int | None = Field( + default=None, description="Random seed for data generation" + ) + + # Dataset configuration + dataset_type: DatasetType = Field( + default=DatasetType.ADVECTION_DIFFUSION, description="Type of dataset" + ) + n_steps_input: int = Field(default=4, description="Number of input time steps") + n_steps_output: int = Field(default=10, description="Number of output time steps") + stride: int = Field(default=1, description="Stride for sampling the data") + input_channel_idxs: tuple[int, ...] | None = Field( + default=None, description="Indices of input channels to use" + ) + output_channel_idxs: tuple[int, ...] | None = Field( + default=None, description="Indices of output channels to use" + ) + + # DataLoader configuration + batch_size: int = Field(default=4, description="Batch size for DataLoader") + dtype: str = Field(default="float32", description="Data type (float32 or float64)") + use_normalization: bool = Field( + default=False, + description="Whether to use Z-score normalization ((x - mean) / std). 
" + "Stats computed from training data and applied to all splits.", + ) + + def get_source_type(self) -> DataSourceType: + """Automatically determine data source type from configuration.""" + if self.source_type is not None: + return self.source_type + + # Auto-detect based on fields + if self.well_dataset_name is not None: + return DataSourceType.WELL_NATIVE + if self.data_path is not None: + return DataSourceType.FILE + return DataSourceType.GENERATED + + +class TeacherForcingConfig(BaseModel): + """Configuration for teacher forcing schedule.""" + + start: float = Field( + default=1.0, description="Starting teacher forcing ratio", ge=0.0, le=1.0 + ) + end: float = Field( + default=0.0, description="Ending teacher forcing ratio", ge=0.0, le=1.0 + ) + schedule_epochs: int | None = Field( + default=None, + description="Number of epochs for schedule (None = use total epochs)", + ) + schedule_type: str = Field( + default="linear", + description="Schedule type: linear, exponential, or step", + ) + mode: str = Field( + default="mix", + description="Mode: mix (tempering) or prob (stochastic)", + ) + min_prob: float = Field( + default=1e-6, description="Minimum probability for numerical stability" + ) + + +class ModelParamsConfig(BaseModel): + """Configuration for model-specific parameters.""" + + # FNO parameters + modes1: int = Field(default=16, description="FNO modes in first spatial dimension") + modes2: int = Field(default=16, description="FNO modes in second spatial dimension") + modes3: int = Field( + default=16, description="FNO modes in third spatial dimension (3D only)" + ) + modes_time: int = Field( + default=3, description="FNO modes in time dimension (for FNOWithTime)" + ) + hidden_channels: int = Field(default=64, description="Number of hidden channels") + gradient_checkpointing: bool = Field( + default=False, description="Enable gradient checkpointing" + ) + + # AFNO parameters + hidden_dim: int = Field(default=64, description="AFNO hidden dimension") + n_blocks: int = Field(default=14, description="AFNO number of blocks") + + # UNet parameters + init_features: int = Field(default=48, description="UNet initial features") + blocks_per_stage: int = Field( + default=2, description="UNet blocks per stage (ConvNext variant)" + ) + + # Allow extra fields for custom model parameters + model_config = {"extra": "allow"} + + +class TrainerConfig(BaseModel): + """Configuration for training.""" + + # Optimizer configuration + optimizer_type: OptimizerType = Field( + default=OptimizerType.ADAM, description="Optimizer type" + ) + optimizer_params: dict[str, Any] = Field( + default_factory=lambda: {"lr": 1e-3}, + description="Optimizer parameters (lr, weight_decay, etc.)", + ) + + # LR Scheduler configuration + lr_scheduler_type: LRSchedulerType | None = Field( + default=None, description="Learning rate scheduler type" + ) + lr_scheduler_params: dict[str, Any] = Field( + default_factory=dict, + description="LR scheduler parameters", + ) + + # Training parameters + epochs: int = Field(default=10, description="Number of training epochs") + checkpoint_frequency: int = Field( + default=5, description="Save checkpoint every N epochs" + ) + val_frequency: int = Field(default=1, description="Validate every N epochs") + rollout_val_frequency: int = Field( + default=1, description="Rollout validation every N epochs" + ) + + # Rollout parameters + max_rollout_steps: int = Field( + default=100, description="Maximum rollout steps for validation" + ) + short_validation_length: int = Field( + default=20, 
description="Short validation length" + ) + make_rollout_videos: bool = Field( + default=True, description="Generate rollout videos" + ) + num_time_intervals: int = Field( + default=5, description="Number of time intervals for metrics" + ) + + # Mixed precision + enable_amp: bool = Field( + default=False, description="Enable automatic mixed precision" + ) + amp_type: str | None = Field( + default=None, description="AMP data type (float16 or bfloat16)" + ) + + # Distributed training + is_distributed: bool = Field( + default=False, description="Enable distributed training" + ) + + # Device + device: str = Field( + default="cpu", description="Device to use (cpu, cuda, mps, cuda:0, etc.)" + ) + + # Teacher forcing + enable_tf_schedule: bool = Field( + default=False, description="Enable scheduled teacher forcing in training" + ) + tf_params: TeacherForcingConfig = Field( + default_factory=TeacherForcingConfig, + description="Teacher forcing schedule parameters", + ) + + # Loss function + loss_fn: LossFunctionType = Field( + default=LossFunctionType.VRMSE, description="Loss function to use" + ) + + # Checkpoint path for resuming + checkpoint_path: str | None = Field( + default=None, description="Path to checkpoint to resume from" + ) + + +class PathsConfig(BaseModel): + """Configuration for input/output paths.""" + + output_dir: Path = Field( + default=Path("./outputs"), description="Base output directory" + ) + data_save_path: Path | None = Field( + default=None, description="Path to save generated data (HDF5 or PT format)" + ) + model_save_path: Path | None = Field( + default=None, description="Path to save trained model state dict" + ) + save_format: str = Field( + default="h5", + description="Data save format: h5 (HDF5) or pt (PyTorch)", + ) + + @field_validator("output_dir", mode="before") + @classmethod + def ensure_path(cls, v): + """Convert string to Path.""" + return Path(v) if isinstance(v, str) else v + + +class ExperimentConfig(BaseModel): + """Top-level configuration for a spatiotemporal emulation experiment.""" + + # Experiment metadata + experiment_name: str = Field( + default="spatiotemporal_experiment", description="Name of the experiment" + ) + description: str = Field(default="", description="Description of the experiment") + + # Core components + emulator_type: EmulatorType = Field( + default=EmulatorType.THE_WELL_FNO, description="Type of emulator to use" + ) + formatter_type: FormatterType = Field( + default=FormatterType.DEFAULT_CHANNELS_FIRST, + description="Type of data formatter to use", + ) + + # Configuration sections + simulator: SimulatorConfig | None = Field( + default=None, + description="Simulator configuration (required if generating data)", + ) + data: DataConfig = Field( + default_factory=DataConfig, description="Data configuration" + ) + model_params: ModelParamsConfig = Field( + default_factory=ModelParamsConfig, + description="Model-specific parameters", + ) + trainer: TrainerConfig = Field( + default_factory=TrainerConfig, description="Training configuration" + ) + paths: PathsConfig = Field( + default_factory=PathsConfig, description="Path configuration" + ) + + # Logging and monitoring + verbose: bool = Field(default=False, description="Enable verbose logging") + log_level: str = Field( + default="INFO", + description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", + ) + + model_config = {"arbitrary_types_allowed": True} + + def save_to_yaml(self, path: str | Path): + """Save configuration to YAML file.""" + path = Path(path) + 
path.parent.mkdir(parents=True, exist_ok=True) + + # Convert to dict and handle Path objects + config_dict = self.model_dump(mode="json") + + with open(path, "w") as f: + yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False) + + @classmethod + def load_from_yaml(cls, path: str | Path) -> "ExperimentConfig": + """Load configuration from YAML file.""" + with open(path) as f: + config_dict = yaml.safe_load(f) + + return cls(**config_dict) diff --git a/autoemulate/experimental/configs/README.md b/autoemulate/experimental/configs/README.md new file mode 100644 index 000000000..cabcb4c1d --- /dev/null +++ b/autoemulate/experimental/configs/README.md @@ -0,0 +1,98 @@ +# Example Configurations + +Example YAML configurations for different data sources and emulation scenarios. + +## Available Configs + +### `turbulent_radiative_layer_2d.yaml` +- **Data:** The Well native dataset +- **Use:** Turbulent flows with radiation +- **Device:** MPS (Apple Silicon) +- **Samples:** 6,984 train / 873 validation / 873 test + +```bash +python autoemulate/experimental/run_the_well_experiment.py \ + --config autoemulate/experimental/configs/turbulent_radiative_layer_2d.yaml +``` + +### `bout.yaml` +- **Data:** File-based (BOUT++ plasma simulations) +- **Use:** Plasma physics, fusion research +- **Device:** MPS +- **Location:** `./autoemulate/experimental/exploratory/data/bout/` + +```bash +python autoemulate/experimental/run_the_well_experiment.py \ + --config autoemulate/experimental/configs/bout.yaml +``` + +### `advection_diffusion_generated.yaml` +- **Data:** Generated from simulator +- **Use:** Transport phenomena, fluid dynamics +- **Device:** MPS +- **Parameters:** nu (viscosity) [0.0001, 0.01], mu (advection) [0.5, 2.0] +- **Samples:** 200 train / 20 validation / 20 test + +```bash +python autoemulate/experimental/run_the_well_experiment.py \ + --config autoemulate/experimental/configs/advection_diffusion_generated.yaml +``` + +### `reaction_diffusion_generated.yaml` +- **Data:** Generated from simulator +- **Use:** Pattern formation, chemical dynamics (Gray-Scott model) +- **Device:** MPS +- **Parameters:** feed_rate [0.02, 0.06], kill_rate [0.045, 0.065] +- **Samples:** 200 train / 20 validation / 20 test + +```bash +python autoemulate/experimental/run_the_well_experiment.py \ + --config autoemulate/experimental/configs/reaction_diffusion_generated.yaml +``` + +## Quick Customizations + +### Change Device +```yaml +trainer: + device: "cuda" # or "mps" or "cpu" +``` + +### Enable Mixed Precision (CUDA only) +```yaml +trainer: + enable_amp: true + amp_type: "float16" +``` + +### Quick Test (Fewer Epochs) +```yaml +trainer: + epochs: 10 +``` + +### Adjust Memory Usage +```yaml +data: + batch_size: 8 # Increase if you have GPU memory +``` + +## Data Source Detection + +The script auto-detects the data source: + +1. **The Well Native** - If `data.well_dataset_name` is set +2. **File-Based** - If `data.data_path` is set (without `well_dataset_name`) +3. **Generated** - If `simulator` section exists + +## Creating Your Own Config + +Copy an existing config and modify: + +1. Change `experiment_name` and `description` +2. Adjust data source (see parent README.md) +3. Tune model parameters (`modes1`, `modes2`, `width`) +4. Set training parameters (`epochs`, `device`, `lr`) +5. Run! + +See `../README.md` for full configuration documentation. 
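
Configs can also be created or tweaked programmatically via the Pydantic models added in `config_models.py` in this diff. Below is a minimal sketch: `ExperimentConfig.load_from_yaml` / `save_to_yaml` and the fields used are defined in `config_models.py`, while the quick-test values and the output filename are illustrative; the end-to-end wiring through `run_the_well_experiment.py` is not shown in this diff and is assumed.

```python
from autoemulate.experimental.config_models import ExperimentConfig

# Load one of the example configs shipped in this directory
cfg = ExperimentConfig.load_from_yaml(
    "autoemulate/experimental/configs/advection_diffusion_generated.yaml"
)

# Adjust a few fields for a quick test run (values are illustrative)
cfg.experiment_name = "advection_diffusion_quick_test"
cfg.trainer.epochs = 10
cfg.data.batch_size = 8

# Write the variant back out as YAML (output path is illustrative;
# save_to_yaml creates parent directories as needed)
cfg.save_to_yaml(
    "autoemulate/experimental/configs/advection_diffusion_quick_test.yaml"
)
```

Because the whole experiment is described by one `ExperimentConfig` object, anything that can be set in YAML can also be set this way before saving.
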
diff --git a/autoemulate/experimental/configs/advection_diffusion_generated.yaml b/autoemulate/experimental/configs/advection_diffusion_generated.yaml new file mode 100644 index 000000000..35d64212f --- /dev/null +++ b/autoemulate/experimental/configs/advection_diffusion_generated.yaml @@ -0,0 +1,108 @@ +# Configuration for Advection-Diffusion dataset with data generation +# This uses the unified run_the_well_experiment.py script with GENERATED data source + +# Basic experiment information +experiment_name: "advection_diffusion_generated" +description: "FNO emulator for advection-diffusion dynamics with data generation" + +# Emulator configuration +emulator_type: "the_well_fno" + +# Data formatter +formatter_type: "default_channels_first" + +# Model architecture parameters +model_params: + modes1: 16 # Fourier modes in first dimension + modes2: 16 # Fourier modes in second dimension + hidden_channels: 128 # Number of channels in the FNO layers + # width: 32 # Width of the FNO layers + # n_blocks: 4 # Number of FNO blocks + +# Simulator configuration for advection-diffusion +# This triggers GENERATED data source type +simulator: + type: "advection_diffusion" + n: 64 # Spatial resolution (64x64 grid) + T: 10.0 # Total simulation time + dt: 0.1 # Time step size + L: 10.0 # Domain size + parameters_range: + nu: [0.0001, 0.01] # Diffusion coefficient (viscosity) + mu: [0.5, 2.0] # Advection strength + output_names: ["solution"] + return_timeseries: true + +# Data configuration - Generated from simulator +data: + # No well_dataset_name - not a Well native dataset + # No data_path - triggers GENERATED source type + + # Data generation + n_train_samples: 100 + n_valid_samples: 20 + n_test_samples: 20 + random_seed: 42 + use_normalization: true + + # Dataset configuration + dataset_type: "advection_diffusion" + n_steps_input: 4 # Number of input timesteps + n_steps_output: 1 # Number of output timesteps (rollout length) + batch_size: 4 # Batch size for training + stride: 1 # Stride for sampling + +# Trainer configuration +trainer: + epochs: 100 + device: "cuda" # Use "cuda" for NVIDIA GPU, "mps" for Apple Silicon, "cpu" for CPU + + optimizer_type: "adamw" + optimizer_params: + lr: 0.01 # Learning rate + weight_decay: 0.0001 + + lr_scheduler_type: "step_lr" + lr_scheduler_params: + step_size: 30 + gamma: 0.5 + + loss_fn: "vrmse" + + val_frequency: 5 + checkpoint_frequency: 20 + rollout_val_frequency: 10 + max_rollout_steps: 100 + short_validation_length: 10 + make_rollout_videos: true + + # Teacher forcing schedule + enable_tf_schedule: true + tf_params: + start: 1.0 + end: 0.0 + schedule_epochs: 50 + schedule_type: "linear" + mode: "mix" + min_prob: 0.0 + + # AMP settings + enable_amp: false + amp_type: null + + # Distributed training + is_distributed: false + checkpoint_path: null + num_time_intervals: 1 + +# Paths +paths: + output_dir: "./outputs/2025-10-22/advection_diffusion_generated" + model_save_path: "./outputs/advection_diffusion_generated/artifacts/final_model.pt" + # Optionally save generated data for reuse + data_save_path: "./outputs/2025-10-22/data/advection_diffusion_generated" + save_format: "h5" + +# Logging +log_level: "INFO" +verbose: true diff --git a/autoemulate/experimental/configs/bout.yaml b/autoemulate/experimental/configs/bout.yaml new file mode 100644 index 000000000..7c72798c5 --- /dev/null +++ b/autoemulate/experimental/configs/bout.yaml @@ -0,0 +1,92 @@ +# Configuration for BOUT++ dataset using file-based data loading +# This uses the unified run_the_well_experiment.py 
script + +# Basic experiment information +experiment_name: "bout_fno" +description: "FNO emulator for BOUT++ plasma simulation data" + +# Emulator configuration +emulator_type: "the_well_fno" + +# Data formatter +formatter_type: "default_channels_first" + +# Model architecture parameters +model_params: + modes1: 12 # Fourier modes in first dimension + modes2: 12 # Fourier modes in second dimension + width: 32 # Width of the FNO layers + n_blocks: 4 # Number of FNO blocks + +# Data configuration - File-based loading +# Setting data_path without well_dataset_name triggers FILE source type +data: + data_path: "./autoemulate/experimental/exploratory/data/bout" # Path to BOUT++ data directory with train/valid/test subdirs + # Structure expected: + # ./data/bout/train/data.pt + # ./data/bout/valid/data.pt + # ./data/bout/test/data.pt + + dataset_type: "bout" # Use BOUT dataset class + n_steps_input: 5 # Number of input timesteps + n_steps_output: 5 # Number of output timesteps (rollout length) + batch_size: 1 # Batch size for training + stride: 1 # Stride for sampling + +# No simulator section needed for file-based data! +# No well_dataset_name needed - this triggers file loading + +# Trainer configuration +trainer: + epochs: 100 + device: "cuda" + + optimizer_type: "adam" + optimizer_params: + lr: 0.0001 # Lower learning rate (1e-4) as used in the notebook + weight_decay: 0.0 + + lr_scheduler_type: "step_lr" + lr_scheduler_params: + step_size: 30 + gamma: 0.5 + + loss_fn: "vrmse" # Or "rmse" or "mse" + + val_frequency: 5 + checkpoint_frequency: 10 + rollout_val_frequency: 10 + max_rollout_steps: 100 + short_validation_length: 10 + make_rollout_videos: true + + # Teacher forcing schedule + enable_tf_schedule: true + tf_params: + start: 1.0 # Full teacher forcing at start + end: 0.0 # No teacher forcing at end + schedule_epochs: 50 # Decay over first 50 epochs + schedule_type: "linear" + mode: "mix" + min_prob: 0.0 + + # AMP settings (optional - enable for faster training on CUDA) + enable_amp: false + amp_type: null + + # Distributed training + is_distributed: false + checkpoint_path: null + num_time_intervals: 1 + +# Paths +paths: + output_dir: "./outputs/bout_fno" + model_save_path: "./outputs/bout_fno/artifacts/final_model.pt" + # Optionally save data in different formats + # data_save_path: "./outputs/bout_fno/data" + # save_format: "h5" + +# Logging +log_level: "INFO" +verbose: true diff --git a/autoemulate/experimental/configs/reaction_diffusion_generated.yaml b/autoemulate/experimental/configs/reaction_diffusion_generated.yaml new file mode 100644 index 000000000..830b0ebc8 --- /dev/null +++ b/autoemulate/experimental/configs/reaction_diffusion_generated.yaml @@ -0,0 +1,113 @@ +# Configuration for Reaction-Diffusion dataset (Gray-Scott model) +# This uses the unified run_the_well_experiment.py script with GENERATED data source + +# Basic experiment information +experiment_name: "reaction_diffusion_generated" +description: "FNO emulator for reaction-diffusion (Gray-Scott) dynamics with data generation" + +# Emulator configuration +emulator_type: "the_well_fno" + +# Data formatter +formatter_type: "default_channels_first" + +# Model architecture parameters +model_params: + modes1: 16 # Fourier modes in first dimension + modes2: 16 # Fourier modes in second dimension + hidden_channels: 128 + # width: 32 # Width of the FNO layers + # n_blocks: 4 # Number of FNO blocks + +# Simulator configuration for reaction-diffusion +# This triggers GENERATED data source type +simulator: + type: 
"reaction_diffusion" + n: 64 # Spatial resolution (64x64 grid) + T: 100.0 # Total simulation time + dt: 1.0 # Time step size + L: 2.5 # Domain size + parameters_range: + # Gray-Scott model parameters + feed_rate: [0.02, 0.06] # Feed rate (F) + kill_rate: [0.045, 0.065] # Kill rate (k) + output_names: ["u", "v"] # Chemical species U and V + return_timeseries: true + +# Data configuration - Generated from simulator +data: + # No well_dataset_name - not a Well native dataset + # No data_path - triggers GENERATED source type + + # Data generation + n_train_samples: 200 + n_valid_samples: 20 + n_test_samples: 20 + random_seed: 42 + use_normalization: true + + # Dataset configuration + dataset_type: "reaction_diffusion" + n_steps_input: 4 # Number of input timesteps + n_steps_output: 1 # Number of output timesteps (rollout length) + batch_size: 4 # Batch size for training + stride: 1 # Stride for sampling + + # Optional: channel indices (if you want to use only U or only V) + # input_channel_idxs: [0] # Only use U channel + # output_channel_idxs: [0] # Only predict U channel + +# Trainer configuration +trainer: + epochs: 100 + device: "cuda" + + optimizer_type: "adamw" + optimizer_params: + lr: 0.01 # Learning rate + weight_decay: 0.0001 + + lr_scheduler_type: "step_lr" + lr_scheduler_params: + step_size: 30 + gamma: 0.5 + + loss_fn: "vrmse" # VRMSE accounts for varying magnitudes over time + + val_frequency: 5 + checkpoint_frequency: 10 + rollout_val_frequency: 10 + max_rollout_steps: 100 + short_validation_length: 10 + make_rollout_videos: true + + # Teacher forcing schedule - helpful for reaction-diffusion + enable_tf_schedule: true + tf_params: + start: 1.0 # Full teacher forcing at start + end: 0.0 # No teacher forcing at end + schedule_epochs: 50 # Decay over first 50 epochs + schedule_type: "linear" + mode: "mix" + min_prob: 0.0 + + # AMP settings (optional - for CUDA) + enable_amp: false + amp_type: null + + # Distributed training + is_distributed: false + checkpoint_path: null + num_time_intervals: 1 + +# Paths +paths: + output_dir: "./outputs/2025-10-22/reaction_diffusion_generated" + model_save_path: "./outputs/reaction_diffusion_generated/artifacts/final_model.pt" + # Optionally save generated data for reuse + data_save_path: "./outputs/2025-10-22/data/reaction_diffusion_generated" + save_format: "h5" # or "pt" for PyTorch format + +# Logging +log_level: "INFO" +verbose: true diff --git a/autoemulate/experimental/configs/turbulent_radiative_layer_2d.yaml b/autoemulate/experimental/configs/turbulent_radiative_layer_2d.yaml new file mode 100644 index 000000000..e63c4e786 --- /dev/null +++ b/autoemulate/experimental/configs/turbulent_radiative_layer_2d.yaml @@ -0,0 +1,80 @@ +# Configuration for turbulent_radiative_layer_2D using The Well's native dataset +# This uses the unified run_the_well_experiment.py script + +# Basic experiment information +experiment_name: "turbulent_radiative_layer_2d" +description: "FNO emulator for turbulent radiative layer 2D using native Well dataset" + +# Emulator configuration +emulator_type: "the_well_fno" + +# Data formatter +formatter_type: "default_channels_first" + +# Model architecture parameters +model_params: + modes1: 16 + modes2: 16 + hidden_channels: 128 + # width: 20 + # n_blocks: 4 + +# Data configuration - The Well native dataset +# Setting well_dataset_name automatically uses WELL_NATIVE source type +data: + well_dataset_name: "turbulent_radiative_layer_2D" # This triggers native Well dataset loading + data_path: 
"./autoemulate/experimental/exploratory/data/the_well/datasets" # Optional: base path to Well datasets + n_steps_input: 4 + n_steps_output: 1 + batch_size: 4 + use_normalization: true + +# No simulator section needed for native Well datasets! + +# Trainer configuration +trainer: + epochs: 200 + device: "cuda" + + optimizer_type: "adamw" + optimizer_params: + lr: 0.01 + weight_decay: 0.0001 + + lr_scheduler_type: "step_lr" + lr_scheduler_params: + step_size: 30 + gamma: 0.5 + + loss_fn: "vrmse" + + val_frequency: 5 + checkpoint_frequency: 20 + rollout_val_frequency: 10 + max_rollout_steps: 100 + short_validation_length: 10 + make_rollout_videos: false + + enable_tf_schedule: true + tf_params: + start: 1.0 + end: 0.0 + schedule_epochs: 50 + schedule_type: "linear" + mode: "mix" + min_prob: 0.0 + + enable_amp: false + amp_type: null + is_distributed: false + checkpoint_path: null + num_time_intervals: 1 + +# Paths +paths: + output_dir: "./outputs/2025-10-22/turbulent_radiative_layer_2d" + model_save_path: "./outputs/2025-10-22/turbulent_radiative_layer_2d/artifacts/final_model.pt" + +# Logging +log_level: "INFO" +verbose: true diff --git a/autoemulate/experimental/data/spatiotemporal_dataset.py b/autoemulate/experimental/data/spatiotemporal_dataset.py index e6510d49a..9e40d0719 100644 --- a/autoemulate/experimental/data/spatiotemporal_dataset.py +++ b/autoemulate/experimental/data/spatiotemporal_dataset.py @@ -8,6 +8,133 @@ from torch.utils.data import DataLoader, Dataset +class SimpleZScoreNormalization: + """Simple Z-score normalization for AutoEmulate datasets. + + Computes mean and std from the data and applies (x - mean) / std normalization. + Compatible with The Well's Trainer interface. + """ + + def __init__(self, data: torch.Tensor, min_std: float = 1e-6): + """Initialize normalization with data statistics. + + Parameters + ---------- + data : torch.Tensor + Training data tensor of shape [N, T, W, H, C] + min_std : float + Minimum standard deviation to avoid division by zero + """ + # Compute statistics over all dimensions except channels + # Shape: [C] - one mean/std per channel + self.mean = data.mean(dim=(0, 1, 2, 3)) + self.std = data.std(dim=(0, 1, 2, 3)).clamp(min=min_std) + self.min_std = min_std + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + """Normalize input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape [..., C] + + Returns + ------- + torch.Tensor + Normalized tensor + """ + return (x - self.mean.to(x.device)) / self.std.to(x.device) + + def denormalize(self, x: torch.Tensor) -> torch.Tensor: + """Denormalize input tensor. + + Parameters + ---------- + x : torch.Tensor + Normalized tensor of shape [..., C] + + Returns + ------- + torch.Tensor + Denormalized tensor + """ + return x * self.std.to(x.device) + self.mean.to(x.device) + + def normalize_flattened( + self, + x: torch.Tensor, + mode: str = "variable", # noqa: ARG002 + ) -> torch.Tensor: + """Normalize tensor where fields are flattened as channels. + + Compatible with The Well's Trainer interface. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape [..., C] + mode : str + Mode for normalization ("variable" or "constant"). + Currently both use same stats. + + Returns + ------- + torch.Tensor + Normalized tensor + """ + return self.normalize(x) + + def denormalize_flattened( + self, + x: torch.Tensor, + mode: str = "variable", # noqa: ARG002 + ) -> torch.Tensor: + """Denormalize tensor where fields are flattened as channels. 
+ + Compatible with The Well's Trainer interface. + + Parameters + ---------- + x : torch.Tensor + Normalized tensor of shape [..., C] + mode : str + Mode for denormalization ("variable" or "constant"). + Currently both use same stats. + + Returns + ------- + torch.Tensor + Denormalized tensor + """ + return self.denormalize(x) + + def delta_denormalize_flattened( + self, + x: torch.Tensor, + mode: str = "variable", # noqa: ARG002 + ) -> torch.Tensor: + """Delta denormalize tensor (for delta models). + + For simple Z-score normalization, delta denormalization only + requires std scaling (no mean shift needed). + + Parameters + ---------- + x : torch.Tensor + Normalized delta tensor of shape [..., C] + mode : str + Mode for denormalization ("variable" or "constant"). + Currently both use same stats. + + Returns + ------- + torch.Tensor + Denormalized delta tensor + """ + return x * self.std.to(x.device) + + class AutoEmulateDataset(Dataset): """A class for spatio-temporal datasets.""" @@ -24,6 +151,8 @@ def __init__( full_trajectory_mode: bool = False, dtype: torch.dtype = torch.float32, verbose: bool = False, + use_normalization: bool = False, + norm: SimpleZScoreNormalization | None = None, ): """ Initialize the dataset. @@ -50,9 +179,15 @@ def __init__( Data type for tensors. Defaults to torch.float32. verbose: bool If True, print dataset information. + use_normalization: bool + Whether to apply Z-score normalization. Defaults to False. + norm: SimpleZScoreNormalization | None + Normalization object (computed from training data). Defaults to None. """ self.dtype = dtype self.verbose = verbose + self.use_normalization = use_normalization + self.norm = norm # Read or parse data self.read_data(data_path) if data_path is not None else self.parse_data(data) @@ -178,9 +313,12 @@ def __len__(self): # noqa: D105 return len(self.all_input_fields) def __getitem__(self, idx): # noqa: D105 + input_fields = self.all_input_fields[idx] + output_fields = self.all_output_fields[idx] + item = { - "input_fields": self.all_input_fields[idx], - "output_fields": self.all_output_fields[idx], + "input_fields": input_fields, + "output_fields": output_fields, } if len(self.all_constant_scalars) > 0: item["constant_scalars"] = self.all_constant_scalars[idx] @@ -222,8 +360,6 @@ def __init__( n_steps_per_trajectory=[self.data.shape[1]] * self.data.shape[0], grid_type="cartesian", ) - self.use_normalization = False - self.norm = None class AdvectionDiffusionDataset(AutoEmulateDataset): @@ -249,8 +385,6 @@ def __init__( n_steps_per_trajectory=[self.data.shape[1]] * self.data.shape[0], grid_type="cartesian", ) - self.use_normalization = False - self.norm = None class BOUTDataset(AutoEmulateDataset): @@ -279,8 +413,6 @@ def __init__( n_steps_per_trajectory=[self.data.shape[1]] * self.data.shape[0], grid_type="cartesian", ) - self.use_normalization = False - self.norm = None # class AutoEmulateDataModule(AbstractDataModule): @@ -302,14 +434,19 @@ def __init__( dtype: torch.dtype = torch.float32, ftype: str = "torch", verbose: bool = False, + use_normalization: bool = False, ): self.verbose = verbose + self.use_normalization = use_normalization + base_path = Path(data_path) if data_path is not None else None suffix = ".pt" if ftype == "torch" else ".h5" fname = f"data{suffix}" train_path = base_path / "train" / fname if base_path is not None else None valid_path = base_path / "valid" / fname if base_path is not None else None test_path = base_path / "test" / fname if base_path is not None else None + + # Create training 
dataset first (without normalization) self.train_dataset = dataset_cls( data_path=str(train_path) if train_path is not None else None, data=data["train"] if data is not None else None, @@ -320,8 +457,25 @@ def __init__( output_channel_idxs=output_channel_idxs, dtype=dtype, verbose=self.verbose, + use_normalization=False, # Temporarily disable to compute stats + norm=None, ) - self.valid_dataset = dataset_cls( + + # Compute normalization from training data if requested + norm = None + if self.use_normalization: + if self.verbose: + print("Computing normalization statistics from training data...") + norm = SimpleZScoreNormalization(self.train_dataset.data) + if self.verbose: + print(f" Mean (per channel): {norm.mean}") + print(f" Std (per channel): {norm.std}") + + # Now enable normalization for training dataset + self.train_dataset.use_normalization = True + self.train_dataset.norm = norm + + self.val_dataset = dataset_cls( data_path=str(valid_path) if valid_path is not None else None, data=data["valid"] if data is not None else None, n_steps_input=n_steps_input, @@ -331,6 +485,8 @@ def __init__( output_channel_idxs=output_channel_idxs, dtype=dtype, verbose=self.verbose, + use_normalization=self.use_normalization, + norm=norm, ) self.test_dataset = dataset_cls( data_path=str(test_path) if test_path is not None else None, @@ -342,6 +498,8 @@ def __init__( output_channel_idxs=output_channel_idxs, dtype=dtype, verbose=self.verbose, + use_normalization=self.use_normalization, + norm=norm, ) self.rollout_val_dataset = dataset_cls( data_path=str(train_path) if train_path is not None else None, @@ -354,6 +512,8 @@ def __init__( full_trajectory_mode=True, dtype=dtype, verbose=self.verbose, + use_normalization=self.use_normalization, + norm=norm, ) self.rollout_test_dataset = dataset_cls( data_path=str(test_path) if test_path is not None else None, @@ -366,6 +526,8 @@ def __init__( full_trajectory_mode=True, dtype=dtype, verbose=self.verbose, + use_normalization=self.use_normalization, + norm=norm, ) self.batch_size = batch_size @@ -378,7 +540,7 @@ def train_dataloader(self) -> DataLoader: def val_dataloader(self) -> DataLoader: """DataLoader for standard validation (not full trajectory rollouts).""" return DataLoader( - self.valid_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1 + self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1 ) def rollout_val_dataloader(self) -> DataLoader: diff --git a/autoemulate/experimental/emulators/fno.py b/autoemulate/experimental/emulators/fno.py index be246f80f..201d09706 100644 --- a/autoemulate/experimental/emulators/fno.py +++ b/autoemulate/experimental/emulators/fno.py @@ -1,7 +1,12 @@ +import inspect + import torch from autoemulate.core.types import OutputLike, TensorLike from autoemulate.experimental.emulators.spatiotemporal import SpatioTemporalEmulator from neuralop.models import FNO +from the_well.benchmark.metrics import VRMSE +from the_well.data.datasets import WellMetadata +from torch import nn from torch.utils.data import DataLoader @@ -16,6 +21,8 @@ def prepare_batch(sample, channels=(0,), with_constants=True, with_time=False): :, :, :, :, channels ] # [batch, time, height, width, len(channels)] + batch_size = x.shape[0] + # Permute both x and y x = x.permute(0, 4, 1, 2, 3) # [batch, len(channels), time, height, width] y = y.permute(0, 4, 1, 2, 3) # [batch, len(channels), time, height, width] @@ -27,8 +34,8 @@ def prepare_batch(sample, channels=(0,), with_constants=True, with_time=False): n_constants = 
constant_scalars.shape[-1] # Add spatio-temporal dims to constants - c_broadcast = constant_scalars.reshape(1, n_constants, 1, 1, 1).expand( - 1, n_constants, time_window, height, width + c_broadcast = constant_scalars.reshape(batch_size, n_constants, 1, 1, 1).expand( + batch_size, n_constants, time_window, height, width ) # Concatenate along channel dimension @@ -45,7 +52,16 @@ class FNOEmulator(SpatioTemporalEmulator): """An FNO emulator.""" def __init__( - self, x=None, y=None, channels: tuple[int, ...] = (0,), *args, **kwargs + self, + x=None, + y=None, + channels: tuple[int, ...] = (0,), + metadata: WellMetadata | None = None, + with_constants: bool = False, + with_time: bool = False, + n_steps_output: int = 1, + loss_fn: nn.Module | None = None, + **kwargs, ): _, _ = x, y # Unused # Ensure parent initialisers run before creating nn.Module attributes @@ -53,6 +69,18 @@ def __init__( self.model = FNO(**kwargs) self.channels = channels self.optimizer = torch.optim.Adam(self.model.parameters()) + self.n_epochs = 10 + self.metadata = metadata + self.loss_fn = loss_fn or VRMSE() + self.with_time = with_time + self.with_constants = with_constants + self.n_steps_output = n_steps_output + if ( + "metadata" in inspect.signature(self.model.__call__).parameters + and metadata is None + ): + msg = "metadata must be provided if model requires it" + raise ValueError(msg) @staticmethod def is_multioutput() -> bool: # noqa: D102 @@ -64,24 +92,32 @@ def _fit( assert isinstance(x, DataLoader), "x currently must be a DataLoader" assert y is None, "y currently must be None" - for idx, batch in enumerate(x): - # Prepare input with constants - x, y = prepare_batch( - batch, channels=self.channels, with_constants=True, with_time=True - ) # type: ignore # noqa: PGH003 + for epoch in range(self.n_epochs): + for idx, batch in enumerate(x): + # print(batch["input_fields"].shape) + # Prepare input with constants + # print(batch["input_fields"].shape, batch["output_fields"].shape) + x_batch, y_batch = prepare_batch( + batch, + channels=self.channels, + with_constants=self.with_constants, + with_time=self.with_time, + ) + + # Predictions + y_pred = self.model(x_batch) - # Predictions - y_pred = self.model(x) + if self.with_time: + y_pred = y_pred[:, :, self.n_steps_output, ...] # B, C, T, ... - # Get loss - # Take the first time idx as the next time step prediction - loss = self.loss_fn(y_pred[:, :, :1, ...], y) + # Get loss + loss = self.loss_fn(y_pred, y_batch, self.metadata).mean() - loss.backward() - self.optimizer.step() - self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() - print(f"sample {idx:5d}, loss: {loss.item():.5e}") + print(f"epoch {epoch}, sample {idx:5d}, loss: {loss.item():.5e}") def forward(self, x: TensorLike): """Forward pass.""" @@ -95,8 +131,13 @@ def _predict(self, x: TensorLike | DataLoader, with_grad: bool) -> OutputLike: for _, batch in enumerate(x): # Prepare input with constants x, _ = prepare_batch( - batch, channels=channels, with_constants=True, with_time=True + batch, + channels=channels, + with_constants=self.with_constants, + with_time=self.with_time, ) out = self(x) + if self.with_time: + out = out[:, :, : self.n_steps_output, ...] # B, C, T, ... 
all_preds.append(out) return torch.cat(all_preds) diff --git a/autoemulate/experimental/emulators/the_well.py b/autoemulate/experimental/emulators/the_well.py index f4ce45b9d..da8d4e413 100644 --- a/autoemulate/experimental/emulators/the_well.py +++ b/autoemulate/experimental/emulators/the_well.py @@ -1,8 +1,9 @@ +import inspect import os from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path -from typing import ClassVar +from typing import Any, ClassVar import torch import torchinfo @@ -10,14 +11,18 @@ from autoemulate.core.device import TorchDeviceMixin, get_torch_device from autoemulate.core.types import DeviceLike, ModelParams, TensorLike from autoemulate.experimental.emulators.spatiotemporal import SpatioTemporalEmulator +from einops import rearrange from the_well.benchmark import models from the_well.benchmark.metrics import validation_metric_suite +from the_well.benchmark.models.common import BaseModel from the_well.benchmark.trainer import Trainer from the_well.data import DeltaWellDataset, WellDataModule from the_well.data.data_formatter import AbstractDataFormatter from the_well.data.datamodule import AbstractDataModule +from the_well.data.datasets import WellMetadata from torch import nn -from torch.optim.lr_scheduler import _LRScheduler +from torch.nn import functional as F +from torch.optim.lr_scheduler import LRScheduler, _LRScheduler from torch.utils.data import DataLoader @@ -34,7 +39,11 @@ class TrainerParams: max_rollout_steps: int = 10 short_validation_length: int = 20 make_rollout_videos: bool = True - lr_scheduler: type[torch.optim.lr_scheduler._LRScheduler] | None = None + lr_scheduler: ( + type[torch.optim.lr_scheduler._LRScheduler] + | Callable[[Any], torch.optim.lr_scheduler.LRScheduler] + | None + ) = None amp_type: str = "float16" # bfloat not supported in FFT num_time_intervals: int = 5 enable_amp: bool = False @@ -42,6 +51,19 @@ class TrainerParams: checkpoint_path: str = "" # Path to a checkpoint to resume from, if any device: DeviceLike = "cpu" output_path: str = "./" + # Enable scheduled teacher forcing in training + enable_tf_schedule: bool = False + # start, end, schedule_epochs, schedule_type, mode, min_prob + tf_params: dict = field( + default_factory=lambda: { + "start": 1.0, + "end": 0.0, + "schedule_epochs": None, # fallback to total epochs + "schedule_type": "linear", # linear | exponential | step + "mode": "mix", # mix (tempering) | prob (stochastic) + "min_prob": 1e-6, + } + ) class AutoEmulateTrainer(Trainer): @@ -55,7 +77,7 @@ def __init__( loss_fn: Callable, datamodule: WellDataModule, optimizer: torch.optim.Optimizer, - lr_scheduler: _LRScheduler | None, + lr_scheduler: _LRScheduler | LRScheduler | None, trainer_params: TrainerParams, ): """Subclass to integrate with AutoEmulate framework and extend functionality. 
@@ -99,6 +121,9 @@ def __init__( self.lr_scheduler = lr_scheduler self.loss_fn = loss_fn + # Flag controlling whether TF schedule is active + self.enable_tf_schedule = trainer_params.enable_tf_schedule + # Remaining trainer params self.is_delta = isinstance(datamodule.train_dataset, DeltaWellDataset) self.validation_suite = [*validation_metric_suite, self.loss_fn] @@ -115,7 +140,8 @@ def __init__( torch.bfloat16 if trainer_params.amp_type == "bfloat16" else torch.float16 ) self.grad_scaler = torch.GradScaler( - self.device.type, enabled=self.enable_amp and self.amp_type != "bfloat16" + self.device.type, + enabled=self.enable_amp and trainer_params.amp_type != "bfloat16", ) self.is_distributed = trainer_params.is_distributed self.best_val_loss = None @@ -124,9 +150,207 @@ def __init__( if self.datamodule.train_dataset.use_normalization: self.dset_norm = self.datamodule.train_dataset.norm self.formatter = formatter_cls(self.dset_metadata) - if len(trainer_params.checkpoint_path) > 0: + if ( + trainer_params.checkpoint_path is not None + and len(trainer_params.checkpoint_path) > 0 + ): self.load_checkpoint(trainer_params.checkpoint_path) + # Teacher Forcing scheduler setup + tf_params = trainer_params.tf_params + self.tf_start = float(tf_params.get("start", 1.0)) + self.tf_end = float(tf_params.get("end", 0.0)) + self.tf_mode = tf_params.get("mode", "mix") + self.tf_type = tf_params.get("schedule_type", "linear") + self.tf_epochs = tf_params.get("schedule_epochs") or self.max_epoch + self.tf_min_prob = float(tf_params.get("min_prob", 1e-6)) + + # Initialize current_epoch so scheduling logic has a defined value pre-training + self.current_epoch = self.starting_epoch + + def train_one_epoch(self, epoch: int, dataloader) -> float: + """Override to expose the current epoch to scheduling utilities. + + Sets `self.current_epoch` before delegating to the base implementation so + `_teacher_forcing_ratio` can derive a reliable epoch index without + re-implementing the full training loop from the upstream Trainer. + """ + self.current_epoch = epoch + + # Defer to base trainer to perform training and collect logs + result = super().train_one_epoch(epoch, dataloader) + if isinstance(result, tuple) and len(result) == 2: + epoch_loss, train_logs = result + else: # Upstream type hint mismatch safeguard + epoch_loss, train_logs = result, {} + + # Augment logs with scheduled TF ratio (0 if disabled) + try: + train_logs["tf/ratio_scheduled"] = float(self._teacher_forcing_ratio()) + # pragma: no cover - defensive + except Exception: + train_logs["tf/ratio_scheduled"] = float("nan") + + return epoch_loss, train_logs # type: ignore as this is the upstream signature + + def _teacher_forcing_ratio(self) -> float: + """Compute the scheduled teacher forcing ratio for the current epoch. + + Returns 0.0 immediately if scheduling is disabled. 
+ """ + if not self.enable_tf_schedule: + return 0.0 + e = max(int(self.current_epoch) - 1, 0) + total = max(self.tf_epochs - 1, 1) + progress = min(e / total, 1.0) + # Base linear interpolation used as default + linear_val = self.tf_start + (self.tf_end - self.tf_start) * progress + + if self.tf_type == "exponential": + # Solve start * gamma^e = end at e=total => gamma = (end/start)^(1/total) + if self.tf_start > 0 and self.tf_end > 0: + gamma = (self.tf_end / self.tf_start) ** (1 / total) + r = self.tf_start * (gamma**e) + else: # fall back to linear if invalid bounds + r = linear_val + elif self.tf_type == "step": + r = self.tf_start if e < total else self.tf_end + else: # linear or unknown -> use linear interpolation + r = linear_val + # Clamp to [0.0, 1.0] numerical floor + return float(max(min(r, 1.0), 0.0)) + + def rollout_model( + self, + model, + batch, + formatter, + train: bool = True, + teacher_forcing: bool = False, + tf_ratio: float | None = None, + ): + """Roll out the model. + + Parameters + ---------- + model: nn.Module + The model to evaluate. + batch: dict + Batch produced by the dataloader (contains `input_fields` and optional + `constant_fields` plus targets). + formatter: AbstractDataFormatter + Formatter that converts batch tensors to model inputs / outputs. + train: bool + If True, uses the scheduled teacher forcing ratio (unless `tf_ratio` is + explicitly provided). If False, no schedule is applied. Defaults to True. + teacher_forcing: bool + Backwards-compatible flag. When evaluating (`train=False`) and `tf_ratio` is + not provided, `teacher_forcing=True` implies full teacher forcing + (ratio = 1.0). During training this flag is ignored because scheduled + teacher forcing is always active. Defaults to False. + tf_ratio: float | None + Explicit teacher forcing ratio override in `[0, 1]`. Highest precedence: if + provided it is used directly (clamped) regardless of mode (train/eval) or + schedule. This allows ad-hoc evaluation at a fixed ratio or reproducing the + original full teacher forcing rollout with `tf_ratio=1.0`. Defaults to None. + + Notes + ----- + Precedence: + 1. `tf_ratio` argument (if not None) + 2. Training schedule (`train=True`) + 3. Full TF on eval if `teacher_forcing=True` + 4. 
No teacher forcing + """ + inputs, y_ref = formatter.process_input(batch) + rollout_steps = min( + y_ref.shape[1], self.max_rollout_steps + ) # Number of timesteps in target + y_ref = y_ref[:, :rollout_steps].to(self.device) + # Create a moving batch of one step at a time + moving_batch = batch + moving_batch["input_fields"] = moving_batch["input_fields"].to(self.device) + if "constant_fields" in moving_batch: + moving_batch["constant_fields"] = moving_batch["constant_fields"].to( + self.device + ) + y_preds = [] + + # Calculate the tf_ratio using precedence rules + def _resolve_tf_ratio(): + if tf_ratio is not None: + return float(max(min(tf_ratio, 1.0), 0.0)) + if train and self.enable_tf_schedule: + return self._teacher_forcing_ratio() + if teacher_forcing: + return 1.0 + return 0.0 + + effective_tf_ratio = _resolve_tf_ratio() + + use_tf = effective_tf_ratio > 0.0 + + for i in range(rollout_steps): + if not train: + moving_batch = self.normalize(moving_batch) + + inputs, _ = formatter.process_input(moving_batch) + inputs = [x.to(self.device) for x in inputs] + y_pred = model(*inputs) + + y_pred = formatter.process_output_channel_last(y_pred) + + if not train: + moving_batch, y_pred = self.denormalize(moving_batch, y_pred) + + if (not train) and self.is_delta: + # TODO: update to handle case when more than single time step + assert { + moving_batch["input_fields"][:, -1, ...].shape == y_pred.shape + }, ( + f"Mismatching shapes between last input timestep " + f"{moving_batch[:, -1, ...].shape} and prediction {y_pred.shape}" + ) + y_pred = moving_batch["input_fields"][:, -1, ...] + y_pred + y_pred = formatter.process_output_expand_time(y_pred) + # If not last step, update moving batch + if i != rollout_steps - 1: + next_moving_batch_tail = moving_batch["input_fields"][:, 1:] + if use_tf: + if self.tf_mode == "prob": + # Sample Bernoulli per batch element deciding to use GT vs pred + mask = ( + torch.rand( + y_pred.shape[0], # generate a mask per batch + 1, + *([1] * (y_pred.dim() - 2)), + device=y_pred.device, + ) + < effective_tf_ratio + ).to(y_pred.dtype) + mixed = mask * y_ref[:, i : i + 1] + (1 - mask) * y_pred + next_moving_batch = torch.cat( + [next_moving_batch_tail, mixed], dim=1 + ) + else: # mix/tempering + mixed = ( + effective_tf_ratio * y_ref[:, i : i + 1] + + (1 - effective_tf_ratio) * y_pred + ) + next_moving_batch = torch.cat( + [next_moving_batch_tail, mixed], dim=1 + ) + else: + # Fully free running + next_moving_batch = torch.cat( + [next_moving_batch_tail, y_pred], dim=1 + ) + moving_batch["input_fields"] = next_moving_batch + y_preds.append(y_pred) + y_pred_out = torch.cat(y_preds, dim=1) + y_ref = y_ref.to(self.device) + return y_pred_out, y_ref + class TheWellEmulator(SpatioTemporalEmulator): """Base class for The Well emulators.""" @@ -134,6 +358,8 @@ class TheWellEmulator(SpatioTemporalEmulator): model: torch.nn.Module model_cls: type[torch.nn.Module] model_parameters: ClassVar[ModelParams] + with_time: bool = False + trainer: AutoEmulateTrainer def __init__( self, @@ -150,7 +376,28 @@ def __init__( TorchDeviceMixin.__init__( self, device=get_torch_device(self.trainer_params.device) ) - super().__init__(**kwargs) + # Init base without nn.Module kwargs + super().__init__() + + # Split incoming kwargs into those intended for the model class vs others. + # Anything matching the model's __init__ signature (excluding self) is + # treated as a model override and merged with class-level model_parameters. 
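Stepping back to the rollout logic above: the two teacher-forcing modes combine ground truth and prediction differently when building the next input frame. A standalone sketch with dummy shapes (this is not the trainer's actual method, just the same blending rule in isolation):

```python
import torch

def blend_next_input(y_pred, y_ref_step, tf_ratio, mode="mix"):
    """Blend the model prediction with the ground-truth frame for the next rollout step."""
    if mode == "prob":
        # Bernoulli mask per batch element: use the ground truth with probability tf_ratio
        mask = (
            torch.rand(
                y_pred.shape[0], 1, *([1] * (y_pred.dim() - 2)), device=y_pred.device
            )
            < tf_ratio
        ).to(y_pred.dtype)
        return mask * y_ref_step + (1 - mask) * y_pred
    # "mix" (tempering): convex combination of ground truth and prediction
    return tf_ratio * y_ref_step + (1 - tf_ratio) * y_pred


# Dummy tensors shaped (batch, time, H, W, channels)
y_pred = torch.randn(4, 1, 32, 32, 2)
y_ref_step = torch.randn(4, 1, 32, 32, 2)
mixed = blend_next_input(y_pred, y_ref_step, tf_ratio=0.7, mode="prob")
```

With `tf_ratio=1.0` both modes reduce to full teacher forcing; with `0.0` both fall back to free running, matching the precedence rules documented in `rollout_model`.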
+ model_sig = inspect.signature(self.model_cls.__init__).parameters + allowed_model_keys = { + name + for name, p in model_sig.items() + if name != "self" + and p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + provided_model_kwargs = { + k: kwargs.pop(k) for k in list(kwargs.keys()) if k in allowed_model_keys + } + # Combine defaults with provided overrides + self._model_kwargs = {**self.model_parameters, **provided_model_kwargs} # Set output path output_path = Path(self.trainer_params.output_path) @@ -163,13 +410,18 @@ def __init__( metadata = datamodule.train_dataset.metadata self.n_steps_input = datamodule.train_dataset.n_steps_input self.n_steps_output = datamodule.train_dataset.n_steps_output - # TODO: aim to be more flexible (such as time as an input channel) - self.n_input_fields = ( - self.n_steps_input * metadata.n_fields + metadata.n_constant_fields - ) + + # Determine whether the model expects time as an explicit dimension + # For such models, channels should repr fields(+constants), not time*fields + if self.with_time: + self.n_input_fields = metadata.n_fields + metadata.n_constant_fields + else: + self.n_input_fields = ( + self.n_steps_input * metadata.n_fields + metadata.n_constant_fields + ) self.n_output_fields = metadata.n_fields self.model = self.model_cls( - **self.model_parameters, + **self._model_kwargs, # TODO: check if general beyond FNO dim_in=self.n_input_fields, dim_out=self.n_output_fields, @@ -260,6 +512,51 @@ def is_multioutput() -> bool: return True +class FNOWithTime(BaseModel): + """FNO with time.""" + + def __init__( + self, + dim_in: int, + dim_out: int, + n_spatial_dims: int, + spatial_resolution: tuple[int, ...], + modes_time: int, + modes1: int, + modes2: int, + modes3: int = 16, + hidden_channels: int = 64, + gradient_checkpointing: bool = False, + ): + super().__init__(n_spatial_dims, spatial_resolution) + self.dim_in = dim_in + self.dim_out = dim_out + self.modes_time = modes_time + self.modes1 = modes1 + self.modes2 = modes2 + self.modes3 = modes3 + self.hidden_channels = hidden_channels + self.model = None + self.initialized = False + self.gradient_checkpointing = gradient_checkpointing + + if self.n_spatial_dims == 2: + self.n_modes = (self.modes_time, self.modes1, self.modes2) + elif self.n_spatial_dims == 3: + self.n_modes = (self.modes_time, self.modes1, self.modes2, self.modes3) + + self.model = models.fno.NeuralOpsCheckpointWrapper( + n_modes=self.n_modes, + in_channels=self.dim_in, + out_channels=self.dim_out, + hidden_channels=self.hidden_channels, + gradient_checkpointing=gradient_checkpointing, + ) + + def forward(self, input) -> torch.Tensor: # noqa: D102 + return self.model(input) # type: ignore # noqa: PGH003 + + class TheWellFNO(TheWellEmulator): """The Well FNO emulator.""" @@ -273,6 +570,171 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +class TheWellFNOWithTime(TheWellEmulator): + """The Well FNO emulator.""" + + with_time: bool = True + model_cls: type[torch.nn.Module] = FNOWithTime + model_parameters: ClassVar[ModelParams] = { + "modes_time": 3, + "modes1": 16, + "modes2": 16, + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DefaultChannelsFirstFormatterWithTime(AbstractDataFormatter): + """ + Default preprocessor for data in channels first format. + + Stacks time as individual channel. + """ + + def process_input(self, data: dict) -> tuple: # noqa: D102 + x = data["input_fields"] + x = rearrange(x, "b ... 
c -> b c ...") # Move channels to before batch + if "constant_fields" in data: + flat_constants = rearrange(data["constant_fields"], "b ... c -> b c 1 ...") + x = torch.cat( + [ + x, + flat_constants, + ], + dim=1, + ) + y = data["output_fields"] + # TODO - Add warning to output if nan has to be replaced + # in some cases (staircase), its ok. In others, it's not. + return (torch.nan_to_num(x),), torch.nan_to_num(y) + + def process_output_channel_last(self, output: torch.Tensor) -> torch.Tensor: # noqa: D102 + return rearrange(output, "b c ... -> b ... c") + + def process_output_expand_time(self, output: torch.Tensor) -> torch.Tensor: # noqa: D102 + # Time does not need to be expanded as it is already included + output = rearrange(output, "b ... c -> b ... c") + # Only take the first temporal slice at the moment since predictions are only + # for one step ahead + return output[:, :1, ...] + + +class TheWellFNOWithLearnableWeights(TheWellEmulator): + """The Well FNO emulator with learnable weights.""" + + model_cls: type[torch.nn.Module] = models.FNO + model_parameters: ClassVar[ModelParams] = { + "modes1": 16, + "modes2": 16, + } + + def __init__( + self, + datamodule: AbstractDataModule | WellDataModule, + formatter_cls: type[AbstractDataFormatter], + loss_fn: Callable, + trainer_params: TrainerParams | None = None, + **kwargs, + ): + # Parameters for the Trainer + self.trainer_params = trainer_params or TrainerParams() + + # Device setup and backend init + TorchDeviceMixin.__init__( + self, device=get_torch_device(self.trainer_params.device) + ) + # Skip TheWellEmulator init as overriden here + SpatioTemporalEmulator.__init__(self, **kwargs) + + # Set output path + output_path = Path(self.trainer_params.output_path) + + # Set datamodule + self.datamodule = datamodule + + if isinstance(datamodule, WellDataModule): + # Load metadata from train_dataset + metadata = datamodule.train_dataset.metadata + self.n_steps_input = datamodule.train_dataset.n_steps_input + self.n_steps_output = datamodule.train_dataset.n_steps_output + # TODO: aim to be more flexible (such as time as an input channel) + self.n_input_fields = ( + self.n_steps_input * metadata.n_fields + metadata.n_constant_fields + ) + self.n_output_fields = metadata.n_fields + self.model = self.model_cls( + **self.model_parameters, + # TODO: check if general beyond FNO + dim_in=self.n_input_fields, + dim_out=self.n_output_fields, + n_spatial_dims=metadata.n_spatial_dims, + spatial_resolution=metadata.spatial_resolution, + ) + # TODO: update with logging + print(torchinfo.summary(self.model, depth=5)) + else: + msg = "Alternative datamodules not yet supported" + raise NotImplementedError(msg) + + # Init optimizer + optimizer = self.trainer_params.optimizer_cls( + self.model.parameters(), **self.trainer_params.optimizer_params + ) + + # Init scheduler + lr_scheduler = ( + self.trainer_params.lr_scheduler(optimizer) + if self.trainer_params.lr_scheduler is not None + else None + ) + + # Learnable weights for loss function + self.weights = nn.Parameter( + torch.ones(self.n_steps_output, device=self.device), requires_grad=True + ) + + # Assign given metric as base loss function + self.base_loss_func = loss_fn + + # Init trainer + self.trainer = AutoEmulateTrainer( + loss_fn=self.custom_loss_fn, # use custom loss fn as callable in trainer + output_path=output_path, + formatter_cls=formatter_cls, + model=self.model, + datamodule=datamodule, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + trainer_params=self.trainer_params, + ) + + # Move 
to device + # TODO: check if this needs updating for distributed handling + self.to(self.device) + + def custom_loss_fn(self, y_pred, y_target, meta: WellMetadata): + """Loss function that uses parameters constructed at init.""" + # Make positive + w = F.softplus(self.weights) + + # Normalize so mean(w) == 1 + w = w / (w.mean() + 1e-12) + + # Reshape to (1, n_steps, spatial_dims, channels) for broadcasting + w = w.view(1, -1, *([1] * meta.n_spatial_dims), 1) + + # Pad with 1s along the time dimension if needed to match y_pred shape + if w.shape[1] != y_pred.shape[1]: + extra = y_pred.shape[1] - w.shape[1] + pad_shape = (1, extra, *([1] * meta.n_spatial_dims), 1) + ones = torch.ones(pad_shape, device=w.device, dtype=w.dtype) + w = torch.cat([w, ones], dim=1) + + # Return a weighted loss + return self.base_loss_func(w * y_pred, w * y_target, meta) + + # TODO: fix this as not initializing correctly at the moment class TheWellAFNO(TheWellEmulator): """The Well AFNO emulator.""" diff --git a/autoemulate/experimental/exploratory/reaction_diffusion_fno.ipynb b/autoemulate/experimental/exploratory/reaction_diffusion_fno.ipynb index 4ecaa66de..f063332a3 100644 --- a/autoemulate/experimental/exploratory/reaction_diffusion_fno.ipynb +++ b/autoemulate/experimental/exploratory/reaction_diffusion_fno.ipynb @@ -22,7 +22,7 @@ "metadata": {}, "outputs": [], "source": [ - "data.keys()\n" + "data.keys()" ] }, { @@ -46,9 +46,9 @@ "source": [ "from torch.utils.data import DataLoader\n", "\n", - "from autoemulate.experimental.data.spatiotemporal_dataset import AutoEmulateDataset\n", + "from autoemulate.experimental.data.spatiotemporal_dataset import ReactionDiffusionDataset\n", "\n", - "dataset = AutoEmulateDataset(data_path=None, data=data, n_steps_input=1, n_steps_output=1)\n", + "dataset = ReactionDiffusionDataset(data_path=None, data=data, n_steps_input=1, n_steps_output=1)\n", "batch_orig = next(iter(DataLoader(dataset)))" ] }, @@ -70,8 +70,8 @@ "outputs": [], "source": [ "# for autoregressive prediction , we need to split at trajectory level \n", - "\n", "from torch.utils.data import DataLoader\n", + "\n", "# Split at trajectory level\n", "n_trajectories = dataset.n_trajectories\n", "train_traj_count = int(0.9 * n_trajectories)\n", @@ -89,16 +89,6 @@ "id": "6", "metadata": {}, "outputs": [], - "source": [ - "train_traj_idxs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], "source": [ "# Create train data\n", "train_data = {\n", @@ -118,18 +108,18 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "7", "metadata": {}, "outputs": [], "source": [ - "train_dataset = AutoEmulateDataset(\n", + "train_dataset = ReactionDiffusionDataset(\n", " data_path=None, \n", " data=train_data, \n", " n_steps_input=1, \n", " n_steps_output=1\n", ")\n", "\n", - "val_dataset = AutoEmulateDataset(\n", + "val_dataset = ReactionDiffusionDataset(\n", " data_path=None, \n", " data=val_data, \n", " n_steps_input=1, \n", @@ -140,19 +130,19 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "8", "metadata": {}, "outputs": [], "source": [ - "train_loader = DataLoader(train_dataset)\n", - "val_loader = DataLoader(val_dataset)\n", + "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)\n", "batch = next(iter(train_loader))" ] }, { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ 
-162,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -172,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -184,23 +174,31 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "12", "metadata": {}, "outputs": [], "source": [ + "from functools import partial\n", "from autoemulate.experimental.emulators.fno import FNOEmulator\n", - "\n", + "from the_well.benchmark.metrics import VRMSE, MSE\n", + "with_constants, with_time = True, False\n", "emulator = FNOEmulator(\n", - " n_modes=(1, 16, 16),\n", + " n_modes=(16, 16),\n", " hidden_channels=16,\n", - " in_channels=3,\n", - " out_channels=1,)\n" + " channels=(0,),\n", + " loss_fn=VRMSE(),\n", + " with_constants=with_constants,\n", + " with_time=with_time,\n", + " metadata=dataset.metadata,\n", + " in_channels=1 + dataset.constant_scalars.shape[1] if with_constants else 1, # type: ignore\n", + " out_channels=1\n", + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -211,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -223,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -233,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -278,7 +276,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -290,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -302,7 +300,7 @@ "y_true = torch.cat(\n", " [\n", " prepare_batch(\n", - " batch, channels=(0,), with_constants=True, with_time=True\n", + " batch, channels=(0,), with_constants=with_constants, with_time=with_time\n", " )[1]\n", " for batch in DataLoader(dataset)\n", " ],\n", @@ -313,20 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "# TODO: fix with autoregressive prediction\n", - "# from torchmetrics import R2Score\n", - "\n", - "# R2Score()(y_pred.reshape(-1).detach(), y_true.reshape(-1).detach()).item()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", + "id": "20", "metadata": {}, "outputs": [], "source": [] diff --git a/autoemulate/experimental/exploratory/the_well/advection_diffusion.ipynb b/autoemulate/experimental/exploratory/the_well/advection_diffusion.ipynb index dba368b7b..df3941b12 100644 --- a/autoemulate/experimental/exploratory/the_well/advection_diffusion.ipynb +++ b/autoemulate/experimental/exploratory/the_well/advection_diffusion.ipynb @@ -7,16 +7,10 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", - "from the_well.data import WellDataModule\n", "import matplotlib.pyplot as plt\n", "import logging\n", "from autoemulate.experimental.emulators.the_well import TheWellFNO\n", - "\n", - "logging.basicConfig(\n", - " level=logging.INFO,\n", - " format=\"%(asctime)s %(name)s %(levelname)s: %(message)s\",\n", - ")" + "from 
autoemulate.experimental.data.spatiotemporal_dataset import AdvectionDiffusionDataset, AutoEmulateDataModule" ] }, { @@ -29,10 +23,11 @@ "# Make an autoemulate datamodule from the_well datamodule\n", "from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n", "rd = AdvectionDiffusion(n=64, T=10, dt=0.1, return_timeseries=True)\n", - "data = rd.forward_samples_spatiotemporal(6)\n", + "data = rd.forward_samples_spatiotemporal(200)\n", + "# data = rd.forward_samples_spatiotemporal(20)\n", "y = data[\"data\"]\n", - "data_valid = rd.forward_samples_spatiotemporal(2)\n", - "data_test = rd.forward_samples_spatiotemporal(2)" + "data_valid = rd.forward_samples_spatiotemporal(4)\n", + "data_test = rd.forward_samples_spatiotemporal(4)" ] }, { @@ -42,17 +37,17 @@ "metadata": {}, "outputs": [], "source": [ - "from autoemulate.experimental.data.spatiotemporal_dataset import AdvectionDiffusionDataset, AutoEmulateDataModule\n", + "logging.basicConfig(level=logging.INFO)\n", "\n", "ae_data_module = AutoEmulateDataModule(\n", " n_steps_input=4,\n", - " n_steps_output=1,\n", + " n_steps_output=10,\n", " data_path=None,\n", " dataset_cls=AdvectionDiffusionDataset,\n", " data={\"train\": data, \"valid\": data_valid, \"test\": data_test},\n", " verbose=False\n", ")\n", - "output_path = \"../data/the_well/runs/advection_diffusion_wip\"" + "output_path = \"../data/the_well/runs/advection_diffusion_wip_fix_2\"" ] }, { @@ -78,8 +73,22 @@ "# Initialize the emulator\n", "from the_well.data.data_formatter import DefaultChannelsFirstFormatter\n", "from the_well.benchmark.metrics import VRMSE\n", - "from autoemulate.experimental.emulators.the_well import TheWellAFNO, TrainerParams\n", + "from autoemulate.experimental.emulators.the_well import TheWellAFNO, TheWellFNOWithTime, TrainerParams, TheWellFNOWithLearnableWeights, DefaultChannelsFirstFormatterWithTime\n", + "\n", + "# FNO with learnable weights\n", + "# em = TheWellFNOWithLearnableWeights(\n", + "# datamodule=ae_data_module,\n", + "# formatter_cls=DefaultChannelsFirstFormatter,\n", + "# loss_fn=VRMSE(),\n", + "# trainer_params=TrainerParams(\n", + "# device=\"mps\",\n", + "# output_path=output_path,\n", + "# max_rollout_steps=100,\n", + "# optimizer_params={\"lr\": 1e-5}\n", + "# )\n", + "# )\n", "\n", + "# FNO with teacher forcing\n", "em = TheWellFNO(\n", " datamodule=ae_data_module,\n", " formatter_cls=DefaultChannelsFirstFormatter,\n", @@ -88,7 +97,24 @@ " device=\"mps\",\n", " output_path=output_path,\n", " max_rollout_steps=100,\n", - " optimizer_params={\"lr\": 1e-3}\n", + " optimizer_params={\"lr\": 1e-3},\n", + " enable_tf_schedule=True,\n", + " )\n", + ")\n", + "\n", + "# FNO with time\n", + "em = TheWellFNOWithTime(\n", + " datamodule=ae_data_module,\n", + " formatter_cls=DefaultChannelsFirstFormatterWithTime,\n", + " loss_fn=VRMSE(),\n", + " modes_time=2,\n", + " modes1=16,\n", + " modes2=16,\n", + " trainer_params=TrainerParams(\n", + " device=\"mps\",\n", + " output_path=output_path,\n", + " max_rollout_steps=100,\n", + " optimizer_params={\"lr\": 1e-2}\n", " )\n", ")" ] @@ -99,6 +125,32 @@ "id": "5", "metadata": {}, "outputs": [], + "source": [ + "# Example of different rollouts\n", + "y_true = batch[\"output_fields\"].to(em.device)\n", + "y_teacher_forcing = em.trainer.rollout_model(em.model, batch, em.trainer.formatter, train=False, teacher_forcing=True)[0]\n", + "y_free_running = em.trainer.rollout_model(em.model, batch, em.trainer.formatter, train=False, teacher_forcing=False)[0]" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# Explore passing predictions to loss\n", + "# vrmse = VRMSE()\n", + "# print(\"Teacher forcing VRMSE:\", vrmse(y_teacher_forcing, y_true, ae_data_module.train_dataset.metadata).detach().cpu())\n", + "# print(\"Free running VRMSE:\", vrmse(y_free_running, y_true, ae_data_module.train_dataset.metadata).detach().cpu()[:1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], "source": [ "# Fit the model\n", "em.fit()" @@ -107,7 +159,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -122,7 +174,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -136,7 +188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +199,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -157,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -168,7 +220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +245,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -203,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "16", "metadata": {}, "outputs": [], "source": [ diff --git a/autoemulate/experimental/exploratory/the_well/boutpp.ipynb b/autoemulate/experimental/exploratory/the_well/boutpp.ipynb index c771213d2..81fd5f41c 100644 --- a/autoemulate/experimental/exploratory/the_well/boutpp.ipynb +++ b/autoemulate/experimental/exploratory/the_well/boutpp.ipynb @@ -33,12 +33,12 @@ "# ../data/bout/test/data.pt\n", "ae_data_module = AutoEmulateDataModule(\n", " n_steps_input=5,\n", - " n_steps_output=1,\n", + " n_steps_output=5,\n", " data_path=\"../data/bout/\",\n", " dataset_cls=BOUTDataset,\n", " verbose=False\n", ")\n", - "output_path = \"../data/the_well/runs/bout_wip\"" + "output_path = \"../data/the_well/runs/bout_wip_new\"" ] }, { @@ -66,15 +66,15 @@ "metadata": {}, "outputs": [], "source": [ - "# Create a function to animate simulation images\n", - "it = iter(ds)\n", - "traj = next(it)[\"input_fields\"]\n", - "_ = make_video(\n", - " traj,\n", - " traj,\n", - " output_dir=output_path,\n", - " metadata=ds.metadata,\n", - ")" + "# # Create a function to animate simulation images\n", + "# it = iter(ds)\n", + "# traj = next(it)[\"input_fields\"]\n", + "# _ = make_video(\n", + "# traj,\n", + "# traj,\n", + "# output_dir=output_path,\n", + "# metadata=ds.metadata,\n", + "# )" ] }, { @@ -98,7 +98,7 @@ "outputs": [], "source": [ "# Initialize the emulator\n", - "from the_well.benchmark.metrics import VRMSE\n", + "from the_well.benchmark.metrics import VRMSE, RMSE\n", "from the_well.data.data_formatter import DefaultChannelsFirstFormatter, DefaultChannelsLastFormatter\n", "from autoemulate.experimental.emulators.the_well import TheWellUNetClassic, TheWellUNetConvNext, TrainerParams\n", "\n", @@ -108,12 +108,15 @@ " formatter_cls=DefaultChannelsFirstFormatter,\n", " 
datamodule=ae_data_module,\n", " loss_fn=VRMSE(),\n", + " # loss_fn=RMSE(),\n", " trainer_params=TrainerParams(\n", " max_rollout_steps=100,\n", " output_path=output_path,\n", " device=\"mps\",\n", - " optimizer_params={\"lr\": 1e-3}\n", - " )\n", + " optimizer_params={\"lr\": 1e-4},\n", + " enable_tf_schedule=True\n", + "\n", + " ),\n", ")" ] }, @@ -199,6 +202,14 @@ "# Run prediction from a non-rollout dataloader\n", "em.predict(ae_data_module.test_dataloader()).shape\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/autoemulate/experimental/exploratory/the_well/reaction_diffusion.ipynb b/autoemulate/experimental/exploratory/the_well/reaction_diffusion.ipynb index 1128bcc95..03ec0fe85 100644 --- a/autoemulate/experimental/exploratory/the_well/reaction_diffusion.ipynb +++ b/autoemulate/experimental/exploratory/the_well/reaction_diffusion.ipynb @@ -9,13 +9,7 @@ "source": [ "import matplotlib.pyplot as plt\n", "import logging\n", - "from autoemulate.experimental.emulators.the_well import TheWellFNO\n", - "\n", - "logging.basicConfig(\n", - " level=logging.INFO,\n", - " format=\"%(asctime)s %(name)s %(levelname)s: %(message)s\",\n", - ")\n", - "logger = logging.getLogger(\"the_well\")" + "from autoemulate.experimental.emulators.the_well import TheWellFNO, TheWellFNOWithLearnableWeights" ] }, { @@ -28,10 +22,10 @@ "# Make an autoemulate datamodule from the_well datamodule\n", "from autoemulate.simulations.reaction_diffusion import ReactionDiffusion\n", "rd = ReactionDiffusion(n=32, T=10, dt=0.1, return_timeseries=True)\n", - "data = rd.forward_samples_spatiotemporal(3)\n", + "data = rd.forward_samples_spatiotemporal(20)\n", "y = data[\"data\"]\n", - "data_valid = rd.forward_samples_spatiotemporal(1)\n", - "data_test = rd.forward_samples_spatiotemporal(1)" + "data_valid = rd.forward_samples_spatiotemporal(4)\n", + "data_test = rd.forward_samples_spatiotemporal(4)" ] }, { @@ -43,11 +37,15 @@ "source": [ "from autoemulate.experimental.data.spatiotemporal_dataset import AutoEmulateDataModule, ReactionDiffusionDataset\n", "\n", + "logging.basicConfig(level=logging.INFO)\n", + "\n", "ae_data_module = AutoEmulateDataModule(\n", + " n_steps_input=5, \n", + " n_steps_output=10,\n", " data_path=None,\n", " dataset_cls=ReactionDiffusionDataset,\n", " data={\"train\": data, \"valid\": data_valid, \"test\": data_test},\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -71,23 +69,33 @@ "metadata": {}, "outputs": [], "source": [ + "from autoemulate.experimental.emulators.the_well import (\n", + " DefaultChannelsFirstFormatterWithTime, TrainerParams, TheWellFNOWithTime\n", + ")\n", "from the_well.data.data_formatter import DefaultChannelsFirstFormatter\n", - "from the_well.benchmark.metrics import VRMSE\n", + "from the_well.benchmark.metrics import VRMSE, RMSE\n", "\n", - "from autoemulate.experimental.emulators.the_well import TrainerParams\n", + "# from autoemulate.experimental.emulators.the_well import (\n", + "# DefaultChannelsFirstFormatterWithTime, TrainerParams, TheWellFNOWithTime\n", + "# )\n", " \n", "# Device set to MPS as example, can also be \"cpu\", \"cuda\" etc\n", "device = \"mps\" # \"cpu\"\n", - "output_path = \"../data/the_well/runs/reaction_diffusion_wip\"\n", + "output_path = \"../data/the_well/runs/reaction_diffusion_wip_lw_new\"\n", "\n", "# Initialize the emulator\n", - "em = TheWellFNO(\n", - " formatter_cls=DefaultChannelsFirstFormatter,\n", - " loss_fn=VRMSE(),\n", 
+ "# em = TheWellFNOWithLearnableWeights(\n", + "em = TheWellFNOWithTime(\n", + " # formatter_cls=DefaultChannelsFirstFormatter,\n", + " formatter_cls=DefaultChannelsFirstFormatterWithTime,\n", + " # loss_fn=VRMSE(),\n", + " loss_fn=RMSE(),\n", " datamodule=ae_data_module,\n", " trainer_params=TrainerParams(\n", " output_path=output_path,\n", - " device=device\n", + " max_rollout_steps=100,\n", + " device=device,\n", + " optimizer_params={\"lr\": 1e-3},\n", " )\n", ")\n" ] diff --git a/autoemulate/experimental/exploratory/the_well/turbulent_radiative_layer_2D.ipynb b/autoemulate/experimental/exploratory/the_well/turbulent_radiative_layer_2D.ipynb new file mode 100644 index 000000000..cb1bab87c --- /dev/null +++ b/autoemulate/experimental/exploratory/the_well/turbulent_radiative_layer_2D.ipynb @@ -0,0 +1,181 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import logging\n", + "from autoemulate.experimental.emulators.the_well import TheWellFNO, TheWellFNOWithLearnableWeights\n", + "from pathlib import Path\n", + "from the_well.data import WellDataModule, WellDataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Make a datamodule\n", + "logging.basicConfig(level=logging.INFO)\n", + "n_steps_input = 4\n", + "n_steps_output = 1\n", + "well_dataset_name=\"turbulent_radiative_layer_2D\"\n", + "ae_data_module = WellDataModule(\n", + " well_base_path=\"../data/the_well/datasets\",\n", + " well_dataset_name=well_dataset_name,\n", + " n_steps_input=n_steps_input,\n", + " n_steps_output=n_steps_output,\n", + " batch_size=4,\n", + " train_dataset=WellDataset,\n", + ")\n", + "output_path = Path(\"../data/the_well/runs\") / f\"{well_dataset_name}_fno\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot example\n", + "batch = next(iter(ae_data_module.val_dataloader()))\n", + "plt.imshow(batch[\"input_fields\"][0, 0, :, :, 0])\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from autoemulate.experimental.emulators.the_well import (\n", + " DefaultChannelsFirstFormatterWithTime, TrainerParams, TheWellFNOWithTime\n", + ")\n", + "from the_well.data.data_formatter import DefaultChannelsFirstFormatter\n", + "from the_well.benchmark.metrics import VRMSE, RMSE\n", + "\n", + "# from autoemulate.experimental.emulators.the_well import (\n", + "# DefaultChannelsFirstFormatterWithTime, TrainerParams, TheWellFNOWithTime\n", + "# )\n", + " \n", + "# Device set to MPS as example, can also be \"cpu\", \"cuda\" etc\n", + "device = \"mps\" # \"cpu\"\n", + "\n", + "# Initialize the emulator\n", + "# em = TheWellFNOWithLearnableWeights(\n", + "em = TheWellFNO(\n", + " formatter_cls=DefaultChannelsFirstFormatter,\n", + " loss_fn=VRMSE(),\n", + " datamodule=ae_data_module,\n", + " trainer_params=TrainerParams(\n", + " output_path=str(output_path),\n", + " max_rollout_steps=100,\n", + " device=device,\n", + " optimizer_params={\"lr\": 1e-3},\n", + " )\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + 
"source": [ + "# Fit the model\n", + "em.fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# Validation loop\n", + "valid_results = em.trainer.validation_loop(\n", + " ae_data_module.rollout_val_dataloader(),\n", + " valid_or_test=\"rollout_valid\",\n", + " full=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "test_results = em.trainer.validation_loop(\n", + " ae_data_module.rollout_test_dataloader(),\n", + " valid_or_test=\"rollout_test\",\n", + " full=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "pprint(valid_results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "pprint(test_results)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/autoemulate/experimental/run_the_well_experiment.py b/autoemulate/experimental/run_the_well_experiment.py new file mode 100644 index 000000000..938c2c8ed --- /dev/null +++ b/autoemulate/experimental/run_the_well_experiment.py @@ -0,0 +1,703 @@ +"""Script to run spatiotemporal emulation experiments from YAML config. + +This script provides a command-line interface for running The Well-based +spatiotemporal emulation experiments using YAML configuration files. 
+ +Example usage: + # Create an example config + python run_the_well_experiment.py --create-example + + # Run an experiment from config + python run_the_well_experiment.py --config config.yaml + + # Override output directory + python run_the_well_experiment.py --config config.yaml --output-dir ./results +""" + +import argparse +import logging +import sys +from collections.abc import Callable +from datetime import datetime +from pathlib import Path +from typing import Any + +import h5py +import torch +from autoemulate.experimental.config_models import ( + DataConfig, + DatasetType, + DataSourceType, + EmulatorType, + ExperimentConfig, + FormatterType, + LossFunctionType, + LRSchedulerType, + ModelParamsConfig, + OptimizerType, + PathsConfig, + SimulatorConfig, + TrainerConfig, +) +from autoemulate.experimental.data.spatiotemporal_dataset import ( + AdvectionDiffusionDataset, + AutoEmulateDataModule, + BOUTDataset, + ReactionDiffusionDataset, +) +from autoemulate.experimental.emulators.the_well import ( + DefaultChannelsFirstFormatterWithTime, + TheWellAFNO, + TheWellEmulator, + TheWellFNO, + TheWellFNOWithLearnableWeights, + TheWellFNOWithTime, + TheWellUNetClassic, + TheWellUNetConvNext, + TrainerParams, +) +from autoemulate.simulations.advection_diffusion import AdvectionDiffusion +from autoemulate.simulations.reaction_diffusion import ReactionDiffusion +from the_well.benchmark.metrics import VRMSE +from the_well.data import WellDataModule, WellDataset +from the_well.data.data_formatter import DefaultChannelsFirstFormatter + +# Add project root to path if needed +project_root = Path(__file__).parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + + +# Set up logger +logger = logging.getLogger(__name__) + + +def setup_logging(output_dir: Path, log_level: str = "INFO", verbose: bool = False): + """Set up logging configuration. + + Parameters + ---------- + output_dir : Path + Directory to save log file + log_level : str + Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + verbose : bool + If True, also log to console with detailed format + """ + # Create logs directory + log_dir = output_dir / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + + # Create log filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = log_dir / f"experiment_{timestamp}.log" + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(getattr(logging, log_level.upper())) + + # Remove any existing handlers + root_logger.handlers = [] + + # File handler - always detailed + file_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(file_formatter) + root_logger.addHandler(file_handler) + + # Console handler + if verbose: + console_formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", + datefmt="%H:%M:%S", + ) + else: + console_formatter = logging.Formatter("%(message)s") + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(getattr(logging, log_level.upper())) + console_handler.setFormatter(console_formatter) + root_logger.addHandler(console_handler) + + msg = f"Logging initialized. 
Log file: {log_file}" + logger.info(msg) + + return log_file + + +def create_simulator(config: ExperimentConfig): + """Create simulator from configuration.""" + if config.simulator is None: + msg = "Simulator config required when generating data" + raise ValueError(msg) + + sim_cfg = config.simulator + + if sim_cfg.type.value == "advection_diffusion": + return AdvectionDiffusion( + parameters_range=sim_cfg.parameters_range, + output_names=sim_cfg.output_names, + return_timeseries=sim_cfg.return_timeseries, + n=sim_cfg.n, + L=sim_cfg.L, + T=sim_cfg.T, + dt=sim_cfg.dt, + ) + if sim_cfg.type.value == "reaction_diffusion": + return ReactionDiffusion( + parameters_range=sim_cfg.parameters_range, + output_names=sim_cfg.output_names, + return_timeseries=sim_cfg.return_timeseries, + n=sim_cfg.n, + L=sim_cfg.L, + T=sim_cfg.T, + dt=sim_cfg.dt, + ) + + msg = f"Unknown simulator type: {sim_cfg.type}" + raise ValueError(msg) + + +def get_dataset_class(dataset_type: DatasetType): + """Get dataset class from type.""" + dataset_classes = { + DatasetType.ADVECTION_DIFFUSION: AdvectionDiffusionDataset, + DatasetType.REACTION_DIFFUSION: ReactionDiffusionDataset, + DatasetType.BOUT: BOUTDataset, + # Add more dataset types as needed + } + + if dataset_type not in dataset_classes: + msg = f"Dataset type {dataset_type} not supported" + raise ValueError(msg) + + return dataset_classes[dataset_type] + + +def generate_or_load_data( + config: ExperimentConfig, simulator=None +) -> WellDataModule | AutoEmulateDataModule: + """Generate or load data and create data module. + + Supports three data source types: + 1. GENERATED: Generate data from simulator + 2. FILE: Load from existing files (HDF5/PyTorch) + 3. WELL_NATIVE: Use The Well's native datasets + """ + data_cfg = config.data + source_type = data_cfg.get_source_type() + + logger.info("Data source type: %s", source_type.value) + + # Get dtype + dtype = torch.float32 if data_cfg.dtype == "float32" else torch.float64 + + # Handle The Well native datasets + if source_type == DataSourceType.WELL_NATIVE: + if data_cfg.well_dataset_name is None: + msg = "well_dataset_name required for WELL_NATIVE source type" + raise ValueError(msg) + + logger.info( + "Creating WellDataModule for dataset: %s", data_cfg.well_dataset_name + ) + logger.info("Base path: %s", data_cfg.data_path or "../data/the_well/datasets") + logger.info( + "n_steps_input: %d, n_steps_output: %d", + data_cfg.n_steps_input, + data_cfg.n_steps_output, + ) + + datamodule = WellDataModule( + well_base_path=data_cfg.data_path or "../data/the_well/datasets", + well_dataset_name=data_cfg.well_dataset_name, + n_steps_input=data_cfg.n_steps_input, + n_steps_output=data_cfg.n_steps_output, + batch_size=data_cfg.batch_size, + train_dataset=WellDataset, + use_normalization=data_cfg.use_normalization, + ) + + logger.info("Training dataset size: %d samples", len(datamodule.train_dataset)) + logger.info("Validation dataset size: %d samples", len(datamodule.val_dataset)) + logger.info("Test dataset size: %d samples", len(datamodule.test_dataset)) + + return datamodule + + # Handle generated and file-based data with AutoEmulateDataModule + dataset_cls = get_dataset_class(data_cfg.dataset_type) + + if source_type == DataSourceType.FILE: + # Load from existing data directory + logger.info("Loading data from %s", data_cfg.data_path) + datamodule = AutoEmulateDataModule( + data_path=data_cfg.data_path, + dataset_cls=dataset_cls, + n_steps_input=data_cfg.n_steps_input, + n_steps_output=data_cfg.n_steps_output, + 
stride=data_cfg.stride, + input_channel_idxs=data_cfg.input_channel_idxs, + output_channel_idxs=data_cfg.output_channel_idxs, + batch_size=data_cfg.batch_size, + dtype=dtype, + verbose=config.verbose, + use_normalization=data_cfg.use_normalization, + ) + else: # GENERATED + # Generate data from simulator + if simulator is None: + msg = "Simulator required for GENERATED source type" + raise ValueError(msg) + + logger.info( + "Generating %d training, %d validation, and %d test samples", + data_cfg.n_train_samples, + data_cfg.n_valid_samples, + data_cfg.n_test_samples, + ) + + # Generate splits + logger.debug("Generating training data...") + data_train = simulator.forward_samples_spatiotemporal( + data_cfg.n_train_samples, data_cfg.random_seed + ) + data_valid = simulator.forward_samples_spatiotemporal( + data_cfg.n_valid_samples, + data_cfg.random_seed + 1 if data_cfg.random_seed else None, + ) + data_test = simulator.forward_samples_spatiotemporal( + data_cfg.n_test_samples, + data_cfg.random_seed + 2 if data_cfg.random_seed else None, + ) + + data = {"train": data_train, "valid": data_valid, "test": data_test} + + # Optionally save generated data + if config.paths.data_save_path: + logger.info("Saving generated data to %s", config.paths.data_save_path) + save_data_splits( + data, config.paths.data_save_path, config.paths.save_format + ) + + # Create data module + datamodule = AutoEmulateDataModule( + data_path=None, + data=data, + dataset_cls=dataset_cls, + n_steps_input=data_cfg.n_steps_input, + n_steps_output=data_cfg.n_steps_output, + stride=data_cfg.stride, + input_channel_idxs=data_cfg.input_channel_idxs, + output_channel_idxs=data_cfg.output_channel_idxs, + batch_size=data_cfg.batch_size, + dtype=dtype, + verbose=config.verbose, + use_normalization=data_cfg.use_normalization, + ) + + return datamodule + + +def save_data_splits(data, base_path, save_format="h5"): + """Save data splits to disk.""" + base_path = Path(base_path) + logger.debug("Saving data in %s format", save_format) + + for split_name, split_data in data.items(): + split_dir = base_path / split_name + split_dir.mkdir(parents=True, exist_ok=True) + + if save_format == "h5": + file_path = split_dir / "data.h5" + logger.debug("Saving %s data to %s", split_name, file_path) + with h5py.File(file_path, "w") as f: + f.create_dataset("data", data=split_data["data"].numpy()) + if split_data["constant_scalars"] is not None: + f.create_dataset( + "constant_scalars", + data=split_data["constant_scalars"].numpy(), + ) + if split_data["constant_fields"] is not None: + f.create_dataset( + "constant_fields", data=split_data["constant_fields"].numpy() + ) + elif save_format == "pt": + file_path = split_dir / "data.pt" + logger.debug("Saving %s data to %s", split_name, file_path) + torch.save(split_data, file_path) + else: + msg = f"Unknown save format: {save_format}" + raise ValueError(msg) + + +def get_formatter_class(formatter_type: FormatterType): + """Get formatter class from type.""" + formatter_classes = { + FormatterType.DEFAULT_CHANNELS_FIRST: DefaultChannelsFirstFormatter, + FormatterType.DEFAULT_CHANNELS_FIRST_WITH_TIME: ( + DefaultChannelsFirstFormatterWithTime + ), + } + + return formatter_classes[formatter_type] + + +def get_loss_function(loss_type: LossFunctionType): + """Get loss function from type.""" + if loss_type == LossFunctionType.VRMSE: + return VRMSE() + if loss_type == LossFunctionType.MSE: + return torch.nn.MSELoss() + if loss_type == LossFunctionType.MAE: + return torch.nn.L1Loss() + + msg = f"Unknown loss 
function: {loss_type}" + raise ValueError(msg) + + +def create_lr_scheduler( + config: ExperimentConfig, +) -> Callable[[Any], torch.optim.lr_scheduler.LRScheduler] | None: + """Create learning rate scheduler factory.""" + if config.trainer.lr_scheduler_type is None: + return None + + sched_type = config.trainer.lr_scheduler_type + params = config.trainer.lr_scheduler_params + + if sched_type == LRSchedulerType.STEP_LR: + return lambda opt: torch.optim.lr_scheduler.StepLR( + opt, + step_size=params.get("step_size", 10), + gamma=params.get("gamma", 0.1), + ) + + if sched_type == LRSchedulerType.EXPONENTIAL_LR: + return lambda opt: torch.optim.lr_scheduler.ExponentialLR( + opt, gamma=params.get("gamma", 0.95) + ) + + if sched_type == LRSchedulerType.COSINE_ANNEALING_LR: + return lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR( + opt, T_max=params.get("T_max", 100) + ) + + if sched_type == LRSchedulerType.REDUCE_LR_ON_PLATEAU: + return lambda opt: torch.optim.lr_scheduler.ReduceLROnPlateau( + opt, + mode=params.get("mode", "min"), + factor=params.get("factor", 0.1), + patience=params.get("patience", 10), + ) + + return None + + +def get_optimizer_class(optimizer_type: OptimizerType): + """Get optimizer class from type.""" + optimizer_classes = { + OptimizerType.ADAM: torch.optim.Adam, + OptimizerType.ADAMW: torch.optim.AdamW, + OptimizerType.SGD: torch.optim.SGD, + OptimizerType.RMSPROP: torch.optim.RMSprop, + } + + return optimizer_classes[optimizer_type] + + +def create_trainer_params(config: ExperimentConfig) -> TrainerParams: + """Create TrainerParams from configuration.""" + trainer_cfg = config.trainer + + return TrainerParams( + optimizer_cls=get_optimizer_class(trainer_cfg.optimizer_type), + optimizer_params=trainer_cfg.optimizer_params, + epochs=trainer_cfg.epochs, + checkpoint_frequency=trainer_cfg.checkpoint_frequency, + val_frequency=trainer_cfg.val_frequency, + rollout_val_frequency=trainer_cfg.rollout_val_frequency, + max_rollout_steps=trainer_cfg.max_rollout_steps, + short_validation_length=trainer_cfg.short_validation_length, + make_rollout_videos=trainer_cfg.make_rollout_videos, + lr_scheduler=create_lr_scheduler(config), + amp_type=trainer_cfg.amp_type, # type: ignore TODO fix the types here + num_time_intervals=trainer_cfg.num_time_intervals, + enable_amp=trainer_cfg.enable_amp, + is_distributed=trainer_cfg.is_distributed, + checkpoint_path=trainer_cfg.checkpoint_path, # type: ignore TODO fix the types here + device=trainer_cfg.device, + output_path=str(config.paths.output_dir), + enable_tf_schedule=trainer_cfg.enable_tf_schedule, + tf_params={ + "start": trainer_cfg.tf_params.start, + "end": trainer_cfg.tf_params.end, + "schedule_epochs": trainer_cfg.tf_params.schedule_epochs, + "schedule_type": trainer_cfg.tf_params.schedule_type, + "mode": trainer_cfg.tf_params.mode, + "min_prob": trainer_cfg.tf_params.min_prob, + }, + ) + + +def create_emulator(config: ExperimentConfig, datamodule) -> TheWellEmulator: + """Create emulator from configuration.""" + emulator_type = config.emulator_type + formatter_cls = get_formatter_class(config.formatter_type) + loss_fn = get_loss_function(config.trainer.loss_fn) + trainer_params = create_trainer_params(config) + + # Get model parameters as kwargs + model_params = config.model_params.model_dump(exclude_unset=True) + + emulator_classes = { + EmulatorType.THE_WELL_FNO: TheWellFNO, + EmulatorType.THE_WELL_FNO_WITH_TIME: TheWellFNOWithTime, + EmulatorType.THE_WELL_FNO_WITH_LEARNABLE_WEIGHTS: ( + TheWellFNOWithLearnableWeights + ), + 
EmulatorType.THE_WELL_AFNO: TheWellAFNO, + EmulatorType.THE_WELL_UNET_CLASSIC: TheWellUNetClassic, + EmulatorType.THE_WELL_UNET_CONVNEXT: TheWellUNetConvNext, + } + + if emulator_type not in emulator_classes: + msg = f"Unknown emulator type: {emulator_type}" + raise ValueError(msg) + + emulator_cls = emulator_classes[emulator_type] + + logger.info("Emulator type: %s", emulator_type.value) + logger.info("Formatter: %s", config.formatter_type.value) + logger.info("Loss function: %s", config.trainer.loss_fn.value) + logger.debug("Model parameters: %s", model_params) + + return emulator_cls( + datamodule=datamodule, + formatter_cls=formatter_cls, + loss_fn=loss_fn, + trainer_params=trainer_params, + **model_params, + ) + + +def run_experiment(config_path: str, output_dir: str | None = None): # noqa: PLR0915 ignoring as mainly logging statements + """Run a complete experiment from config file.""" + # Load configuration + logger.info("Loading configuration from %s", config_path) + config = ExperimentConfig.load_from_yaml(config_path) + + # Override output directory if provided + if output_dir: + config.paths.output_dir = Path(output_dir) + + # Create output directory + config.paths.output_dir.mkdir(parents=True, exist_ok=True) + + # Set up logging + setup_logging( + config.paths.output_dir, + log_level=config.log_level, + verbose=config.verbose, + ) + + # Save config to output directory for reproducibility + config.save_to_yaml(config.paths.output_dir / "config.yaml") + logger.info("Configuration saved to %s", config.paths.output_dir / "config.yaml") + + logger.info("=" * 60) + logger.info("Experiment: %s", config.experiment_name) + if config.description: + logger.info("Description: %s", config.description) + logger.info("Emulator type: %s", config.emulator_type.value) + logger.info("Output directory: %s", config.paths.output_dir) + logger.info("=" * 60) + + # Create simulator if needed + simulator = None + if config.data.data_path is None: + logger.info("Creating simulator...") + simulator = create_simulator(config) + if config.simulator: + logger.info("Simulator type: %s", config.simulator.type.value) + logger.info( + "Simulator params: n=%d, T=%.1f, dt=%.2f", + config.simulator.n, + config.simulator.T, + config.simulator.dt, + ) + + # Generate or load data + logger.info("Preparing data...") + datamodule = generate_or_load_data(config, simulator) + + # Log dataset sizes (handle both AutoEmulate and Well data modules) + logger.info("Training dataset size: %d samples", len(datamodule.train_dataset)) + + # WellDataModule uses 'val_dataset', AutoEmulateDataModule uses 'valid_dataset' + if hasattr(datamodule, "valid_dataset") or hasattr(datamodule, "val_dataset"): + logger.info("Validation dataset size: %d samples", len(datamodule.val_dataset)) + + logger.info("Test dataset size: %d samples", len(datamodule.test_dataset)) + + # Create emulator + logger.info("Creating emulator...") + emulator = create_emulator(config, datamodule) + + # Train emulator + logger.info("") + logger.info("=" * 60) + logger.info("TRAINING") + logger.info("=" * 60) + logger.info("Device: %s", config.trainer.device) + logger.info("Number of epochs: %d", config.trainer.epochs) + logger.info("Batch size: %d", config.data.batch_size) + logger.info( + "Optimizer: %s (lr=%.2e)", + config.trainer.optimizer_type.value, + config.trainer.optimizer_params.get("lr", "N/A"), + ) + logger.info("Teacher forcing enabled: %s", config.trainer.enable_tf_schedule) + if config.trainer.enable_tf_schedule: + logger.info( + " TF schedule: %s (%.2f 
-> %.2f)", + config.trainer.tf_params.schedule_type, + config.trainer.tf_params.start, + config.trainer.tf_params.end, + ) + logger.info("Starting training...") + logger.info("") + + emulator.fit() + + # Evaluate emulator + logger.info("") + logger.info("=" * 60) + logger.info("EVALUATION") + logger.info("=" * 60) + logger.info("Evaluating on validation set...") + _, valid_results = emulator.trainer.validation_loop( + datamodule.rollout_val_dataloader(), valid_or_test="rollout_valid", full=True + ) # type: ignore # noqa: PGH003 + + logger.info("Evaluating on test set...") + _, test_results = emulator.trainer.validation_loop( + datamodule.rollout_test_dataloader(), valid_or_test="rollout_test", full=True + ) # type: ignore # noqa: PGH003 + + # Save model if path is provided + if config.paths.model_save_path: + save_path = Path(config.paths.model_save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + logger.info("Saving model to %s", save_path) + torch.save(emulator.state_dict(), save_path) + + # Print results + logger.info("") + logger.info("=" * 60) + logger.info("RESULTS") + logger.info("=" * 60) + logger.info("") + logger.info("Validation metrics:") + for key, value in valid_results.items(): + logger.info(" %s: %s", key, value) + logger.info("") + logger.info("Test metrics:") + for key, value in test_results.items(): + logger.info(" %s: %s", key, value) + logger.info("") + logger.info("=" * 60) + logger.info("") + + logger.info("Outputs saved to: %s", config.paths.output_dir) + logger.info("Log file saved to: %s/logs/", config.paths.output_dir) + logger.info("") + logger.info("Experiment complete!") + + return emulator, valid_results, test_results + + +def create_example_config(output_path: str = "example_config.yaml"): + """Create an example configuration file.""" + config = ExperimentConfig( + experiment_name="advection_diffusion_the_well_example", + description="Example configuration for advection-diffusion with The Well FNO", + emulator_type=EmulatorType.THE_WELL_FNO, + formatter_type=FormatterType.DEFAULT_CHANNELS_FIRST, + simulator=SimulatorConfig( + n=64, + T=10.0, + dt=0.1, + return_timeseries=True, + ), + data=DataConfig( + n_train_samples=200, + n_valid_samples=4, + n_test_samples=4, + n_steps_input=4, + n_steps_output=10, + batch_size=4, + ), + model_params=ModelParamsConfig( + modes1=16, + modes2=16, + ), + trainer=TrainerConfig( + epochs=10, + max_rollout_steps=100, + optimizer_params={"lr": 1e-3}, + device="cpu", # Change to "cuda" or "mps" as needed + ), + paths=PathsConfig( + output_dir=Path("./outputs/the_well_example"), + ), + ) + + config.save_to_yaml(output_path) + print(f"Example configuration saved to {output_path}") + + +def main(): + """Run the main entry point.""" + parser = argparse.ArgumentParser( + description="Run spatiotemporal emulation experiments using The Well" + ) + parser.add_argument( + "--config", + type=str, + help="Path to YAML configuration file", + ) + parser.add_argument( + "--output-dir", + type=str, + help="Override output directory from config", + ) + parser.add_argument( + "--create-example", + action="store_true", + help="Create an example configuration file", + ) + parser.add_argument( + "--example-output", + type=str, + default="example_config.yaml", + help="Output path for example config", + ) + + args = parser.parse_args() + + if args.create_example: + create_example_config(args.example_output) + return + + if not args.config: + parser.error("--config is required (or use --create-example)") + + run_experiment(args.config, 
diff --git a/autoemulate/simulations/reaction_diffusion.py b/autoemulate/simulations/reaction_diffusion.py
index 013c66d0f..7485198fb 100644
--- a/autoemulate/simulations/reaction_diffusion.py
+++ b/autoemulate/simulations/reaction_diffusion.py
@@ -22,7 +22,7 @@ def __init__(
         return_timeseries: bool = False,
         log_level: str = "progress_bar",
         n: int = 32,
-        L: int = 20,
+        L: float = 20.0,
         T: float = 10.0,
         dt: float = 0.1,
     ):
diff --git a/pyproject.toml b/pyproject.toml
index 127ad8df9..409f3aa21 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,7 +73,9 @@ dev = [
     "hydra-core>=1.3.2",
 ]
 spatiotemporal = [
+    "pydantic>=2.11.10",
     "the-well[dev,benchmark] @ git+https://github.com/PolymathicAI/the_well.git@a123419",
+    # "triton>=3.0.0",
 ]
 
 [tool.setuptools]
@@ -167,5 +169,6 @@ convention = "numpy"
 # Platform-specific triton handling:
 # - On macOS arm64: fetch triton from the GitHub repo and build from source (uv/pip will attempt to build)
 override-dependencies = [
-    "triton @ git+https://github.com/openai/triton.git@main; sys_platform == 'darwin' and platform_machine == 'arm64'",
+    #"triton @ git+https://github.com/openai/triton.git@main; sys_platform == 'darwin' and platform_machine == 'arm64'",
+    "triton @ git+https://github.com/openai/triton.git@main",
 ]
diff --git a/scripts/run_trainer.sh b/scripts/run_trainer.sh
new file mode 100644
index 000000000..0fc317c1b
--- /dev/null
+++ b/scripts/run_trainer.sh
@@ -0,0 +1,33 @@
+#!/bin/sh
+
+#SBATCH --account=vjgo8416-ai-phy-sys
+#SBATCH --qos turing
+#SBATCH --time 04:00:00
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --tasks-per-node 1
+#SBATCH --job-name turbulent_radiative_layer_2D_download
+
+module purge
+module load baskerville
+module load bask-apps/live
+module load Python/3.10.8-GCCcore-12.2.0
+module load FFmpeg/5.1.2-GCCcore-12.2.0
+
+BASE_PATH=/bask/homes/l/ltcx7228/vjgo8416-ai-phy-sys/
+
+source .venv/bin/activate
+
+RUN_PATH=autoemulate/experimental
+
+
+## To fix this error:
+## RuntimeError: Deterministic behavior was enabled with either
+## torch.use_deterministic_algorithms(True) or at::Context::setDeterministicAlgorithms(true),
+## but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2.
+## To enable deterministic behavior in this case, you must set an environment variable before
+## running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8.
+## For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+export CUBLAS_WORKSPACE_CONFIG=:4096:8
+
+python $RUN_PATH/run_the_well_experiment.py --config $RUN_PATH/configs/advection_diffusion_generated.yaml
diff --git a/scripts/run_trainer_advection_diffusion.sh b/scripts/run_trainer_advection_diffusion.sh
new file mode 100644
index 000000000..d8d57c76b
--- /dev/null
+++ b/scripts/run_trainer_advection_diffusion.sh
@@ -0,0 +1,33 @@
+#!/bin/sh
+
+#SBATCH --account=vjgo8416-ai-phy-sys
+#SBATCH --qos turing
+#SBATCH --time 24:00:00
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --tasks-per-node 1
+#SBATCH --job-name advection_diffusion
+
+module purge
+module load baskerville
+module load bask-apps/live
+module load Python/3.10.8-GCCcore-12.2.0
+module load FFmpeg/5.1.2-GCCcore-12.2.0
+
+BASE_PATH=/bask/homes/l/ltcx7228/vjgo8416-ai-phy-sys/
+
+source .venv/bin/activate
+
+RUN_PATH=autoemulate/experimental
+
+
+## To fix this error:
+## RuntimeError: Deterministic behavior was enabled with either
+## torch.use_deterministic_algorithms(True) or at::Context::setDeterministicAlgorithms(true),
+## but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2.
+## To enable deterministic behavior in this case, you must set an environment variable before
+## running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8.
+## For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+export CUBLAS_WORKSPACE_CONFIG=:4096:8
+
+python $RUN_PATH/run_the_well_experiment.py --config $RUN_PATH/configs/advection_diffusion_generated.yaml
diff --git a/scripts/run_trainer_bout.sh b/scripts/run_trainer_bout.sh
new file mode 100644
index 000000000..1c1a3cf36
--- /dev/null
+++ b/scripts/run_trainer_bout.sh
@@ -0,0 +1,33 @@
+#!/bin/sh
+
+#SBATCH --account=vjgo8416-ai-phy-sys
+#SBATCH --qos turing
+#SBATCH --time 04:00:00
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --tasks-per-node 1
+#SBATCH --job-name bout
+
+module purge
+module load baskerville
+module load bask-apps/live
+module load Python/3.10.8-GCCcore-12.2.0
+module load FFmpeg/5.1.2-GCCcore-12.2.0
+
+BASE_PATH=/bask/homes/l/ltcx7228/vjgo8416-ai-phy-sys/
+
+source .venv/bin/activate
+
+RUN_PATH=autoemulate/experimental
+
+
+## To fix this error:
+## RuntimeError: Deterministic behavior was enabled with either
+## torch.use_deterministic_algorithms(True) or at::Context::setDeterministicAlgorithms(true),
+## but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2.
+## To enable deterministic behavior in this case, you must set an environment variable before
+## running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8.
+## For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+export CUBLAS_WORKSPACE_CONFIG=:4096:8
+
+python $RUN_PATH/run_the_well_experiment.py --config $RUN_PATH/configs/bout.yaml
diff --git a/scripts/run_trainer_reaction_diffusion.sh b/scripts/run_trainer_reaction_diffusion.sh
new file mode 100644
index 000000000..6f39c50d0
--- /dev/null
+++ b/scripts/run_trainer_reaction_diffusion.sh
@@ -0,0 +1,33 @@
+#!/bin/sh
+
+#SBATCH --account=vjgo8416-ai-phy-sys
+#SBATCH --qos turing
+#SBATCH --time 24:00:00
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --tasks-per-node 1
+#SBATCH --job-name reaction_diffusion
+
+module purge
+module load baskerville
+module load bask-apps/live
+module load Python/3.10.8-GCCcore-12.2.0
+module load FFmpeg/5.1.2-GCCcore-12.2.0
+
+BASE_PATH=/bask/homes/l/ltcx7228/vjgo8416-ai-phy-sys/
+
+source .venv/bin/activate
+
+RUN_PATH=autoemulate/experimental
+
+
+## To fix this error:
+## RuntimeError: Deterministic behavior was enabled with either
+## torch.use_deterministic_algorithms(True) or at::Context::setDeterministicAlgorithms(true),
+## but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2.
+## To enable deterministic behavior in this case, you must set an environment variable before
+## running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8.
+## For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+export CUBLAS_WORKSPACE_CONFIG=:4096:8
+
+python $RUN_PATH/run_the_well_experiment.py --config $RUN_PATH/configs/reaction_diffusion_generated.yaml
diff --git a/scripts/run_trainer_turbulent_radiative_layer_2d.sh b/scripts/run_trainer_turbulent_radiative_layer_2d.sh
new file mode 100644
index 000000000..183e41cbe
--- /dev/null
+++ b/scripts/run_trainer_turbulent_radiative_layer_2d.sh
@@ -0,0 +1,33 @@
+#!/bin/sh
+
+#SBATCH --account=vjgo8416-ai-phy-sys
+#SBATCH --qos turing
+#SBATCH --time 24:00:00
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --tasks-per-node 1
+#SBATCH --job-name turbulent_radiative_layer_2D
+
+module purge
+module load baskerville
+module load bask-apps/live
+module load Python/3.10.8-GCCcore-12.2.0
+module load FFmpeg/5.1.2-GCCcore-12.2.0
+
+BASE_PATH=/bask/homes/l/ltcx7228/vjgo8416-ai-phy-sys/
+
+source .venv/bin/activate
+
+RUN_PATH=autoemulate/experimental
+
+
+## To fix this error:
+## RuntimeError: Deterministic behavior was enabled with either
+## torch.use_deterministic_algorithms(True) or at::Context::setDeterministicAlgorithms(true),
+## but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2.
+## To enable deterministic behavior in this case, you must set an environment variable before
+## running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8.
+## For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + +python $RUN_PATH/run_the_well_experiment.py --config $RUN_PATH/configs/turbulent_radiative_layer_2d.yaml diff --git a/uv.lock b/uv.lock index 6f327c423..abed0473c 100644 --- a/uv.lock +++ b/uv.lock @@ -11,10 +11,7 @@ resolution-markers = [ ] [manifest] -overrides = [ - { name = "triton", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", git = "https://github.com/openai/triton.git?rev=main" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", specifier = "==3.4.0" }, -] +overrides = [{ name = "triton", git = "https://github.com/openai/triton.git?rev=main" }] [[package]] name = "accessible-pygments" @@ -37,6 +34,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92", size = 13511, upload-time = "2024-01-10T00:56:08.388Z" }, ] +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -162,6 +168,7 @@ docs = [ { name = "sphinx-copybutton" }, ] spatiotemporal = [ + { name = "pydantic" }, { name = "the-well", extra = ["benchmark", "dev"] }, ] @@ -190,6 +197,7 @@ requires-dist = [ { name = "pandas", specifier = ">=2.1" }, { name = "plotnine", marker = "extra == 'dev'", specifier = ">=0.13.6" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" }, + { name = "pydantic", marker = "extra == 'spatiotemporal'", specifier = ">=2.11.10" }, { name = "pyright", marker = "extra == 'dev'", specifier = "==1.1.405" }, { name = "pyro-ppl", specifier = ">=1.9.1" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" }, @@ -2281,6 +2289,91 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552, upload-time = "2024-03-30T13:22:20.476Z" }, ] +[[package]] +name = "pydantic" +version = "2.11.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/54/ecab642b3bed45f7d5f59b38443dcb36ef50f85af192e6ece103dbfe9587/pydantic-2.11.10.tar.gz", hash = "sha256:dc280f0982fbda6c38fada4e476dc0a4f3aeaf9c6ad4c28df68a666ec3c61423", size = 788494, upload-time = "2025-10-04T10:40:41.338Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/bd/1f/73c53fcbfb0b5a78f91176df41945ca466e71e9d9d836e5c522abda39ee7/pydantic-2.11.10-py3-none-any.whl", hash = "sha256:802a655709d49bd004c31e865ef37da30b540786a46bfce02333e0e24b5fe29a", size = 444823, upload-time = "2025-10-04T10:40:39.055Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817, upload-time = "2025-04-23T18:30:43.919Z" }, + { url = "https://files.pythonhosted.org/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357, upload-time = "2025-04-23T18:30:46.372Z" }, + { url = "https://files.pythonhosted.org/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011, upload-time = "2025-04-23T18:30:47.591Z" }, + { url = "https://files.pythonhosted.org/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730, upload-time = "2025-04-23T18:30:49.328Z" }, + { url = "https://files.pythonhosted.org/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178, upload-time = "2025-04-23T18:30:50.907Z" }, + { url = "https://files.pythonhosted.org/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462, upload-time = "2025-04-23T18:30:52.083Z" }, + { url = "https://files.pythonhosted.org/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652, upload-time = "2025-04-23T18:30:53.389Z" }, + { url = "https://files.pythonhosted.org/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306, upload-time = "2025-04-23T18:30:54.661Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720, upload-time = "2025-04-23T18:30:56.11Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915, upload-time = "2025-04-23T18:30:57.501Z" }, + { url = "https://files.pythonhosted.org/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884, upload-time = "2025-04-23T18:30:58.867Z" }, + { url = "https://files.pythonhosted.org/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496, upload-time = "2025-04-23T18:31:00.078Z" }, + { url = "https://files.pythonhosted.org/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019, upload-time = "2025-04-23T18:31:01.335Z" }, + { url = "https://files.pythonhosted.org/packages/3f/8d/71db63483d518cbbf290261a1fc2839d17ff89fce7089e08cad07ccfce67/pydantic_core-2.33.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4c5b0a576fb381edd6d27f0a85915c6daf2f8138dc5c267a57c08a62900758c7", size = 2028584, upload-time = "2025-04-23T18:31:03.106Z" }, + { url = "https://files.pythonhosted.org/packages/24/2f/3cfa7244ae292dd850989f328722d2aef313f74ffc471184dc509e1e4e5a/pydantic_core-2.33.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e799c050df38a639db758c617ec771fd8fb7a5f8eaaa4b27b101f266b216a246", size = 1855071, upload-time = "2025-04-23T18:31:04.621Z" }, + { url = "https://files.pythonhosted.org/packages/b3/d3/4ae42d33f5e3f50dd467761304be2fa0a9417fbf09735bc2cce003480f2a/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc46a01bf8d62f227d5ecee74178ffc448ff4e5197c756331f71efcc66dc980f", size = 1897823, upload-time = "2025-04-23T18:31:06.377Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f3/aa5976e8352b7695ff808599794b1fba2a9ae2ee954a3426855935799488/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a144d4f717285c6d9234a66778059f33a89096dfb9b39117663fd8413d582dcc", size = 1983792, upload-time = "2025-04-23T18:31:07.93Z" }, + { url = "https://files.pythonhosted.org/packages/d5/7a/cda9b5a23c552037717f2b2a5257e9b2bfe45e687386df9591eff7b46d28/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cf6373c21bc80b2e0dc88444f41ae60b2f070ed02095754eb5a01df12256de", size = 2136338, upload-time = "2025-04-23T18:31:09.283Z" }, + { url = "https://files.pythonhosted.org/packages/2b/9f/b8f9ec8dd1417eb9da784e91e1667d58a2a4a7b7b34cf4af765ef663a7e5/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dc625f4aa79713512d1976fe9f0bc99f706a9dee21dfd1810b4bbbf228d0e8a", size = 2730998, upload-time = 
"2025-04-23T18:31:11.7Z" }, + { url = "https://files.pythonhosted.org/packages/47/bc/cd720e078576bdb8255d5032c5d63ee5c0bf4b7173dd955185a1d658c456/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b21b5549499972441da4758d662aeea93f1923f953e9cbaff14b8b9565aef", size = 2003200, upload-time = "2025-04-23T18:31:13.536Z" }, + { url = "https://files.pythonhosted.org/packages/ca/22/3602b895ee2cd29d11a2b349372446ae9727c32e78a94b3d588a40fdf187/pydantic_core-2.33.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bdc25f3681f7b78572699569514036afe3c243bc3059d3942624e936ec93450e", size = 2113890, upload-time = "2025-04-23T18:31:15.011Z" }, + { url = "https://files.pythonhosted.org/packages/ff/e6/e3c5908c03cf00d629eb38393a98fccc38ee0ce8ecce32f69fc7d7b558a7/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fe5b32187cbc0c862ee201ad66c30cf218e5ed468ec8dc1cf49dec66e160cc4d", size = 2073359, upload-time = "2025-04-23T18:31:16.393Z" }, + { url = "https://files.pythonhosted.org/packages/12/e7/6a36a07c59ebefc8777d1ffdaf5ae71b06b21952582e4b07eba88a421c79/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:bc7aee6f634a6f4a95676fcb5d6559a2c2a390330098dba5e5a5f28a2e4ada30", size = 2245883, upload-time = "2025-04-23T18:31:17.892Z" }, + { url = "https://files.pythonhosted.org/packages/16/3f/59b3187aaa6cc0c1e6616e8045b284de2b6a87b027cce2ffcea073adf1d2/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:235f45e5dbcccf6bd99f9f472858849f73d11120d76ea8707115415f8e5ebebf", size = 2241074, upload-time = "2025-04-23T18:31:19.205Z" }, + { url = "https://files.pythonhosted.org/packages/e0/ed/55532bb88f674d5d8f67ab121a2a13c385df382de2a1677f30ad385f7438/pydantic_core-2.33.2-cp311-cp311-win32.whl", hash = "sha256:6368900c2d3ef09b69cb0b913f9f8263b03786e5b2a387706c5afb66800efd51", size = 1910538, upload-time = "2025-04-23T18:31:20.541Z" }, + { url = "https://files.pythonhosted.org/packages/fe/1b/25b7cccd4519c0b23c2dd636ad39d381abf113085ce4f7bec2b0dc755eb1/pydantic_core-2.33.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e063337ef9e9820c77acc768546325ebe04ee38b08703244c1309cccc4f1bab", size = 1952909, upload-time = "2025-04-23T18:31:22.371Z" }, + { url = "https://files.pythonhosted.org/packages/49/a9/d809358e49126438055884c4366a1f6227f0f84f635a9014e2deb9b9de54/pydantic_core-2.33.2-cp311-cp311-win_arm64.whl", hash = "sha256:6b99022f1d19bc32a4c2a0d544fc9a76e3be90f0b3f4af413f87d38749300e65", size = 1897786, upload-time = "2025-04-23T18:31:24.161Z" }, + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = 
"2025-04-23T18:31:28.956Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199, upload-time = "2025-04-23T18:31:31.025Z" }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296, upload-time = "2025-04-23T18:31:32.514Z" }, + { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109, upload-time = "2025-04-23T18:31:33.958Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028, upload-time = "2025-04-23T18:31:39.095Z" }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044, upload-time = "2025-04-23T18:31:41.034Z" }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881, upload-time = "2025-04-23T18:31:42.757Z" }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034, upload-time = "2025-04-23T18:31:44.304Z" }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187, upload-time = "2025-04-23T18:31:45.891Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628, upload-time = "2025-04-23T18:31:47.819Z" }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866, upload-time = "2025-04-23T18:31:49.635Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = 
"sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894, upload-time = "2025-04-23T18:31:51.609Z" }, + { url = "https://files.pythonhosted.org/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982, upload-time = "2025-04-23T18:32:53.14Z" }, + { url = "https://files.pythonhosted.org/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412, upload-time = "2025-04-23T18:32:55.52Z" }, + { url = "https://files.pythonhosted.org/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749, upload-time = "2025-04-23T18:32:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527, upload-time = "2025-04-23T18:32:59.771Z" }, + { url = "https://files.pythonhosted.org/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225, upload-time = "2025-04-23T18:33:04.51Z" }, + { url = "https://files.pythonhosted.org/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490, upload-time = "2025-04-23T18:33:06.391Z" }, + { url = "https://files.pythonhosted.org/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525, upload-time = "2025-04-23T18:33:08.44Z" }, + { url = "https://files.pythonhosted.org/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446, upload-time = "2025-04-23T18:33:10.313Z" }, + { url = "https://files.pythonhosted.org/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678, upload-time = "2025-04-23T18:33:12.224Z" }, + { url = "https://files.pythonhosted.org/packages/7b/27/d4ae6487d73948d6f20dddcd94be4ea43e74349b56eba82e9bdee2d7494c/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:dd14041875d09cc0f9308e37a6f8b65f5585cf2598a53aa0123df8b129d481f8", size = 2025200, upload-time = "2025-04-23T18:33:14.199Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/b8/b3cb95375f05d33801024079b9392a5ab45267a63400bf1866e7ce0f0de4/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d87c561733f66531dced0da6e864f44ebf89a8fba55f31407b00c2f7f9449593", size = 1859123, upload-time = "2025-04-23T18:33:16.555Z" }, + { url = "https://files.pythonhosted.org/packages/05/bc/0d0b5adeda59a261cd30a1235a445bf55c7e46ae44aea28f7bd6ed46e091/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f82865531efd18d6e07a04a17331af02cb7a651583c418df8266f17a63c6612", size = 1892852, upload-time = "2025-04-23T18:33:18.513Z" }, + { url = "https://files.pythonhosted.org/packages/3e/11/d37bdebbda2e449cb3f519f6ce950927b56d62f0b84fd9cb9e372a26a3d5/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bfb5112df54209d820d7bf9317c7a6c9025ea52e49f46b6a2060104bba37de7", size = 2067484, upload-time = "2025-04-23T18:33:20.475Z" }, + { url = "https://files.pythonhosted.org/packages/8c/55/1f95f0a05ce72ecb02a8a8a1c3be0579bbc29b1d5ab68f1378b7bebc5057/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64632ff9d614e5eecfb495796ad51b0ed98c453e447a76bcbeeb69615079fc7e", size = 2108896, upload-time = "2025-04-23T18:33:22.501Z" }, + { url = "https://files.pythonhosted.org/packages/53/89/2b2de6c81fa131f423246a9109d7b2a375e83968ad0800d6e57d0574629b/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f889f7a40498cc077332c7ab6b4608d296d852182211787d4f3ee377aaae66e8", size = 2069475, upload-time = "2025-04-23T18:33:24.528Z" }, + { url = "https://files.pythonhosted.org/packages/b8/e9/1f7efbe20d0b2b10f6718944b5d8ece9152390904f29a78e68d4e7961159/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:de4b83bb311557e439b9e186f733f6c645b9417c84e2eb8203f3f820a4b988bf", size = 2239013, upload-time = "2025-04-23T18:33:26.621Z" }, + { url = "https://files.pythonhosted.org/packages/3c/b2/5309c905a93811524a49b4e031e9851a6b00ff0fb668794472ea7746b448/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f68293f055f51b51ea42fafc74b6aad03e70e191799430b90c13d643059ebb", size = 2238715, upload-time = "2025-04-23T18:33:28.656Z" }, + { url = "https://files.pythonhosted.org/packages/32/56/8a7ca5d2cd2cda1d245d34b1c9a942920a718082ae8e54e5f3e5a58b7add/pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1", size = 2066757, upload-time = "2025-04-23T18:33:30.645Z" }, +] + [[package]] name = "pydata-sphinx-theme" version = "0.15.4" @@ -3488,8 +3581,7 @@ dependencies = [ { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", version = "3.4.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "triton", version = "3.5.0+gitcb3ef9fa", source = { git = "https://github.com/openai/triton.git?rev=main#cb3ef9faca161b3229298bd655763dc52ec95042" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "triton" }, { name = "typing-extensions" }, ] wheels = [ @@ -3515,8 +3607,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = 
"python_full_version < '3.11'" }, { name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "torch" }, - { name = "triton", version = "3.4.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "triton", version = "3.5.0+gitcb3ef9fa", source = { git = "https://github.com/openai/triton.git?rev=main#cb3ef9faca161b3229298bd655763dc52ec95042" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, + { name = "triton" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e9/02/2bcd5eb625dec116191ec40f66a04070b65d5251cae4865dc27af0d1fa24/torch_harmonics-0.6.5.tar.gz", hash = "sha256:e467d04bc58eb2dc800eb21870025407d38ebcbf8df4de479bd5b4915daf987e", size = 47767, upload-time = "2024-01-30T14:09:20.23Z" } wheels = [ @@ -3634,33 +3725,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, ] -[[package]] -name = "triton" -version = "3.4.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", -] -dependencies = [ - { name = "setuptools", marker = "sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, - { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" }, - { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" }, -] - [[package]] name = "triton" version = "3.5.0+gitcb3ef9fa" source = { git = "https://github.com/openai/triton.git?rev=main#cb3ef9faca161b3229298bd655763dc52ec95042" } -resolution-markers = [ - "python_full_version >= '3.12' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'linux'", - "python_full_version < '3.11' and sys_platform != 'linux'", -] [[package]] name = "typing-extensions" @@ -3671,6 +3739,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" }, ] +[[package]] +name = "typing-inspection" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ 
+ { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac53367ae42139cf4b1ca5f36bb3dc6c9d33acdb43655/typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464", size = 75949, upload-time = "2025-10-01T02:14:41.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, +] + [[package]] name = "tzdata" version = "2025.2"