diff --git a/examples/reference_full_config/full_config.yml b/examples/reference_full_config/full_config.yml index a9525b085..8a3e2c0cb 100644 --- a/examples/reference_full_config/full_config.yml +++ b/examples/reference_full_config/full_config.yml @@ -3,16 +3,16 @@ # Most experiments do not require a runner.yml # When a runner.yml is used, it is not required to specify all configuration settings. -# For typical usage, please refer to the example runner yamls in examples/example_runner_yamls/, -# instead of using this full configuration file. +# For typical usage, please refer to the example runner yamls in examples/example_runner_yamls/, +# instead of using this full configuration file. ####################### # ExperimentSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L247 -# Overall experiment settings +# Overall experiment settings experiment_settings: mode: predict [predict | train] - output_dir: ./ + output_dir: ./ log_dir: null seeds: - 42 @@ -20,10 +20,11 @@ experiment_settings: use_msa_server: true use_templates: true skip_existing: false + record_memory_snapshot: false # PLTrainerArgs: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L121 # Configuration settings for pytorch lightning -pl_trainer_args: +pl_trainer_args: max_epochs: 1000 accelerator: gpu precision: 32-true @@ -39,7 +40,7 @@ pl_trainer_args: # ModelUpdate: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/project_entry.py#L28 # Custom model settings and presets -model_update: +model_update: presets: # [train | predict | low_mem ] - predict - low_mem @@ -52,8 +53,8 @@ inference_ckpt_path: $HOME/.openfold3/of3_ft3_v1.pt cache_path: $HOME/.openfold3 # DataModuleArgs: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L110 -data_module_args: - batch_size: 1 # Note: Not tested for batch sizes > 1 +data_module_args: + batch_size: 1 # Note: Not tested for batch sizes > 1 data_seed: 42 num_workers: 10 num_workers_validation: 4 # Change num workers for thread limited workflows @@ -61,11 +62,11 @@ data_module_args: # DatasetConfigKwargs: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/dataset_configs.py#L270 # Arguments for creating template and MSA features -dataset_config_kwargs: +dataset_config_kwargs: ccd_file_path: null # if null, uses CCD from Biotite # MSA Settings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/dataset_config_components.py#L32 # Use this section to customize parsing of MSAs into features, more information in docs/source/precomputed_msa_how_to.md - msa: + msa: max_rows_paired: 8191 max_rows: 16384 subsample_with_bands: false @@ -117,7 +118,7 @@ dataset_config_kwargs: # OutputWriterSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L141 # Configure the output formats / content -output_writer_settings: +output_writer_settings: structure_format: cif [cif | pdb | cif.gz] full_confidence_output_format: json [json | npz] full_confidence_output_dtype: float16 [float32 | float16] @@ -126,18 +127,18 @@ output_writer_settings: write_full_confidence_scores: true # MSAComputationSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/tools/colabfold_msa_server.py#L904 -# Configure MSA server settings for colabfold -msa_computation_settings: +# Configure MSA server settings for colabfold +msa_computation_settings: msa_file_format: npz server_user_agent: openfold server_url: https://api.colabfold.com/ save_mappings: true - msa_output_directory: /of3-of-/colabfold_msas/ + msa_output_directory: /of3-of-/colabfold_msas/ cleanup_msa_dir: true # TemplatePreprocessorSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/pipelines/preprocessing/template.py#L1459 # Configure template processing settings, more information in docs/source/template_how_to.md -template_preprocessor_settings: +template_preprocessor_settings: mode: predict moltypes: - 0 diff --git a/openfold3/core/utils/callbacks.py b/openfold3/core/utils/callbacks.py index 15fad920d..4ee96335b 100644 --- a/openfold3/core/utils/callbacks.py +++ b/openfold3/core/utils/callbacks.py @@ -86,6 +86,54 @@ def on_predict_batch_end( runtime_file.write_text(json.dumps(runtime_json, indent=4)) +class MemorySnapshotCallback(pl.Callback): + """Record peak GPU memory and dump a torch memory-history snapshot per batch. + + Must be registered AFTER PredictTimer so the snapshot dump runs outside + the timer's measurement window in on_predict_batch_end. + """ + + def __init__(self, output_dir: Path): + super().__init__() + self.output_dir = Path(output_dir) + + def on_predict_batch_start( + self, trainer, pl_module, batch, batch_idx, dataloader_idx: int = 0 + ): + if not torch.cuda.is_available(): + return + torch.cuda.reset_peak_memory_stats() + torch.cuda.memory._record_memory_history( + max_entries=10_000_000, enabled="all" + ) + + def on_predict_batch_end( + self, + trainer, + pl_module, + outputs, + batch, + batch_idx, + dataloader_idx=0, + ): + if not torch.cuda.is_available(): + return + torch.cuda.synchronize() + peak_gib = torch.cuda.max_memory_allocated() / (1024**3) + logger.info(f"[memsnap] peak={peak_gib:.2f} GiB") + batch_size = len(batch["query_id"]) + for b in range(batch_size): + seed = batch["seed"][b] + query_id = batch["query_id"][b] + snapshot_path = ( + self.output_dir / query_id / f"seed_{seed}" / "mem_snapshot.pkl" + ) + snapshot_path.parent.mkdir(parents=True, exist_ok=True) + torch.cuda.memory._dump_snapshot(str(snapshot_path)) + logger.info(f"[memsnap] wrote {snapshot_path}") + torch.cuda.memory._record_memory_history(enabled=None) + + def set_seed_for_rank(seed: int, rank: int) -> None: """ Sets the seed for all relevant random number generators on a specific rank. diff --git a/openfold3/entry_points/experiment_runner.py b/openfold3/entry_points/experiment_runner.py index ff7637e26..3fab2ec3a 100644 --- a/openfold3/entry_points/experiment_runner.py +++ b/openfold3/entry_points/experiment_runner.py @@ -44,6 +44,7 @@ from openfold3.core.runners.writer import OF3OutputWriter from openfold3.core.utils.callbacks import ( LogInferenceQuerySet, + MemorySnapshotCallback, PredictTimer, RankSpecificSeedCallback, ) @@ -562,6 +563,7 @@ def __init__( use_msa_server: bool = False, use_templates: bool = False, output_dir: Path | None = None, + record_memory_snapshot: bool = False, ): super().__init__(experiment_config) @@ -579,6 +581,7 @@ def __init__( output_dir, use_msa_server, use_templates, + record_memory_snapshot, ) def set_num_diffusion_samples(self, num_diffusion_samples: int) -> None: @@ -609,6 +612,7 @@ def update_config_with_cli_args( output_dir: Path | None, use_msa_server: bool = False, use_templates: bool = False, + record_memory_snapshot: bool = False, ): """Updates configuration given command line args.""" if output_dir: @@ -628,6 +632,9 @@ def update_config_with_cli_args( if use_templates: self.experiment_config.experiment_settings.use_templates = True + if record_memory_snapshot: + self.experiment_config.experiment_settings.record_memory_snapshot = True + @cached_property def use_msa_server(self) -> bool: return self.experiment_config.experiment_settings.use_msa_server @@ -732,6 +739,8 @@ def callbacks(self): PredictTimer(self.output_dir), LogInferenceQuerySet(self.output_dir), ] + if self.experiment_config.experiment_settings.record_memory_snapshot: + _callbacks.append(MemorySnapshotCallback(self.output_dir)) return _callbacks @cached_property diff --git a/openfold3/entry_points/validator.py b/openfold3/entry_points/validator.py index 33f0f03f3..75e560b0c 100644 --- a/openfold3/entry_points/validator.py +++ b/openfold3/entry_points/validator.py @@ -271,6 +271,7 @@ class InferenceExperimentSettings(ExperimentSettings): use_msa_server: bool = False use_templates: bool = False skip_existing: bool = False + record_memory_snapshot: bool = False @model_validator(mode="after") def generate_seeds(self): diff --git a/openfold3/run_openfold.py b/openfold3/run_openfold.py index 732e00783..2f41743c3 100644 --- a/openfold3/run_openfold.py +++ b/openfold3/run_openfold.py @@ -146,6 +146,15 @@ def train(runner_yaml: Path, seed: int | None = None, data_seed: int | None = No required=False, help="Output directory for writing results", ) +@click.option( + "--record-memory-snapshot", + "--record_memory_snapshot", + is_flag=True, + default=False, + help="Record a torch CUDA memory-history snapshot per predict batch to " + "//seed_/mem_snapshot.pkl. Snapshots are large " + "(~1-1.5 GiB per batch for big sequences).", +) def predict( query_json: Path, inference_ckpt_path: Path | None = None, @@ -156,6 +165,7 @@ def predict( use_msa_server: bool = True, use_templates: bool = True, output_dir: Path | None = None, + record_memory_snapshot: bool = False, ): """Perform inference on a set of queries defined in the query_json.""" _torch_gpu_setup() @@ -185,6 +195,7 @@ def predict( use_msa_server, use_templates, output_dir, + record_memory_snapshot, ) # Load inference query set @@ -225,12 +236,12 @@ def align_msa_server( msa_computation_settings_yaml: Path | None = None, ): """Run MSA server alignment only with ColabFold MSA server. - + Example command: python run_openfold.py align-msa-server \ --query_json query_example.json \ --output_dir output/msa_server_test \ - + More settings can be specified using the `msa_computation_settings_yaml` flag An example yaml file is provided in `examples/msa_server.yml` """ diff --git a/openfold3/tests/test_entry_points.py b/openfold3/tests/test_entry_points.py index 179528584..adbdfc2e9 100644 --- a/openfold3/tests/test_entry_points.py +++ b/openfold3/tests/test_entry_points.py @@ -29,6 +29,7 @@ from openfold3 import setup_openfold from openfold3.core.config import config_utils from openfold3.core.data.framework.data_module import DataModuleConfig +from openfold3.core.utils.callbacks import MemorySnapshotCallback from openfold3.entry_points.experiment_runner import ( InferenceExperimentRunner, TrainingExperimentRunner, @@ -442,6 +443,23 @@ def test_use_templates_cli(self, use_templates_cli_arg, tmp_path, dummy_ckpt_fil ) assert expt_runner.use_templates == use_templates_cli_arg + @pytest.mark.parametrize("record_memory_snapshot_cli_arg", [True, False]) + def test_record_memory_snapshot_cli( + self, record_memory_snapshot_cli_arg, dummy_ckpt_file + ): + expt_config = InferenceExperimentConfig(inference_ckpt_path=dummy_ckpt_file) + expt_runner = InferenceExperimentRunner( + expt_config, record_memory_snapshot=record_memory_snapshot_cli_arg + ) + assert ( + expt_config.experiment_settings.record_memory_snapshot + == record_memory_snapshot_cli_arg + ) + has_callback = any( + isinstance(cb, MemorySnapshotCallback) for cb in expt_runner.callbacks + ) + assert has_callback == record_memory_snapshot_cli_arg + def test_seeding_from_num_seeds(self, dummy_ckpt_file): expt_config = InferenceExperimentConfig(inference_ckpt_path=dummy_ckpt_file) num_seeds = 7