Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions examples/reference_full_config/full_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,28 @@
# Most experiments do not require a runner.yml
# When a runner.yml is used, it is not required to specify all configuration settings.

# For typical usage, please refer to the example runner yamls in examples/example_runner_yamls/,
# instead of using this full configuration file.
# For typical usage, please refer to the example runner yamls in examples/example_runner_yamls/,
# instead of using this full configuration file.

#######################

# ExperimentSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L247
# Overall experiment settings
# Overall experiment settings
experiment_settings:
mode: predict [predict | train]
output_dir: ./
output_dir: ./
log_dir: null
seeds:
- 42
num_seeds: null
use_msa_server: true
use_templates: true
skip_existing: false
record_memory_snapshot: false

# PLTrainerArgs: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L121
# Configuration settings for pytorch lightning
pl_trainer_args:
pl_trainer_args:
max_epochs: 1000
accelerator: gpu
precision: 32-true
Expand All @@ -39,7 +40,7 @@ pl_trainer_args:

# ModelUpdate: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/project_entry.py#L28
# Custom model settings and presets
model_update:
model_update:
presets: # [train | predict | low_mem ]
- predict
- low_mem
Expand All @@ -52,20 +53,20 @@ inference_ckpt_path: $HOME/.openfold3/of3_ft3_v1.pt
cache_path: $HOME/.openfold3

# DataModuleArgs: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L110
data_module_args:
batch_size: 1 # Note: Not tested for batch sizes > 1
data_module_args:
batch_size: 1 # Note: Not tested for batch sizes > 1
data_seed: 42
num_workers: 10
num_workers_validation: 4 # Change num workers for thread limited workflows
epoch_len: 4

# DatasetConfigKwargs: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/dataset_configs.py#L270
# Arguments for creating template and MSA features
dataset_config_kwargs:
dataset_config_kwargs:
ccd_file_path: null # if null, uses CCD from Biotite
# MSA Settings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/dataset_config_components.py#L32
# Use this section to customize parsing of MSAs into features, more information in docs/source/precomputed_msa_how_to.md
msa:
msa:
max_rows_paired: 8191
max_rows: 16384
subsample_with_bands: false
Expand Down Expand Up @@ -117,7 +118,7 @@ dataset_config_kwargs:

# OutputWriterSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/entry_points/validator.py#L141
# Configure the output formats / content
output_writer_settings:
output_writer_settings:
structure_format: cif [cif | pdb | cif.gz]
full_confidence_output_format: json [json | npz]
full_confidence_output_dtype: float16 [float32 | float16]
Expand All @@ -126,18 +127,18 @@ output_writer_settings:
write_full_confidence_scores: true

# MSAComputationSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/tools/colabfold_msa_server.py#L904
# Configure MSA server settings for colabfold
msa_computation_settings:
# Configure MSA server settings for colabfold
msa_computation_settings:
msa_file_format: npz
server_user_agent: openfold
server_url: https://api.colabfold.com/
save_mappings: true
msa_output_directory: <tmp-dir>/of3-of-<user>/colabfold_msas/
msa_output_directory: <tmp-dir>/of3-of-<user>/colabfold_msas/
cleanup_msa_dir: true

# TemplatePreprocessorSettings: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/data/pipelines/preprocessing/template.py#L1459
# Configure template processing settings, more information in docs/source/template_how_to.md
template_preprocessor_settings:
template_preprocessor_settings:
mode: predict
moltypes:
- 0
Expand Down
48 changes: 48 additions & 0 deletions openfold3/core/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,54 @@ def on_predict_batch_end(
runtime_file.write_text(json.dumps(runtime_json, indent=4))


class MemorySnapshotCallback(pl.Callback):
"""Record peak GPU memory and dump a torch memory-history snapshot per batch.

Must be registered AFTER PredictTimer so the snapshot dump runs outside
the timer's measurement window in on_predict_batch_end.
"""

def __init__(self, output_dir: Path):
super().__init__()
self.output_dir = Path(output_dir)

def on_predict_batch_start(
self, trainer, pl_module, batch, batch_idx, dataloader_idx: int = 0
):
if not torch.cuda.is_available():
return
torch.cuda.reset_peak_memory_stats()
torch.cuda.memory._record_memory_history(
max_entries=10_000_000, enabled="all"
)

def on_predict_batch_end(
self,
trainer,
pl_module,
outputs,
batch,
batch_idx,
dataloader_idx=0,
):
if not torch.cuda.is_available():
return
torch.cuda.synchronize()
peak_gib = torch.cuda.max_memory_allocated() / (1024**3)
logger.info(f"[memsnap] peak={peak_gib:.2f} GiB")
batch_size = len(batch["query_id"])
for b in range(batch_size):
seed = batch["seed"][b]
query_id = batch["query_id"][b]
snapshot_path = (
self.output_dir / query_id / f"seed_{seed}" / "mem_snapshot.pkl"
)
snapshot_path.parent.mkdir(parents=True, exist_ok=True)
torch.cuda.memory._dump_snapshot(str(snapshot_path))
logger.info(f"[memsnap] wrote {snapshot_path}")
torch.cuda.memory._record_memory_history(enabled=None)


def set_seed_for_rank(seed: int, rank: int) -> None:
"""
Sets the seed for all relevant random number generators on a specific rank.
Expand Down
9 changes: 9 additions & 0 deletions openfold3/entry_points/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from openfold3.core.runners.writer import OF3OutputWriter
from openfold3.core.utils.callbacks import (
LogInferenceQuerySet,
MemorySnapshotCallback,
PredictTimer,
RankSpecificSeedCallback,
)
Expand Down Expand Up @@ -562,6 +563,7 @@ def __init__(
use_msa_server: bool = False,
use_templates: bool = False,
output_dir: Path | None = None,
record_memory_snapshot: bool = False,
):
super().__init__(experiment_config)

Expand All @@ -579,6 +581,7 @@ def __init__(
output_dir,
use_msa_server,
use_templates,
record_memory_snapshot,
)

def set_num_diffusion_samples(self, num_diffusion_samples: int) -> None:
Expand Down Expand Up @@ -609,6 +612,7 @@ def update_config_with_cli_args(
output_dir: Path | None,
use_msa_server: bool = False,
use_templates: bool = False,
record_memory_snapshot: bool = False,
):
"""Updates configuration given command line args."""
if output_dir:
Expand All @@ -628,6 +632,9 @@ def update_config_with_cli_args(
if use_templates:
self.experiment_config.experiment_settings.use_templates = True

if record_memory_snapshot:
self.experiment_config.experiment_settings.record_memory_snapshot = True

@cached_property
def use_msa_server(self) -> bool:
return self.experiment_config.experiment_settings.use_msa_server
Expand Down Expand Up @@ -732,6 +739,8 @@ def callbacks(self):
PredictTimer(self.output_dir),
LogInferenceQuerySet(self.output_dir),
]
if self.experiment_config.experiment_settings.record_memory_snapshot:
_callbacks.append(MemorySnapshotCallback(self.output_dir))
return _callbacks

@cached_property
Expand Down
1 change: 1 addition & 0 deletions openfold3/entry_points/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ class InferenceExperimentSettings(ExperimentSettings):
use_msa_server: bool = False
use_templates: bool = False
skip_existing: bool = False
record_memory_snapshot: bool = False

@model_validator(mode="after")
def generate_seeds(self):
Expand Down
15 changes: 13 additions & 2 deletions openfold3/run_openfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ def train(runner_yaml: Path, seed: int | None = None, data_seed: int | None = No
required=False,
help="Output directory for writing results",
)
@click.option(
"--record-memory-snapshot",
"--record_memory_snapshot",
is_flag=True,
default=False,
help="Record a torch CUDA memory-history snapshot per predict batch to "
"<output_dir>/<query_id>/seed_<n>/mem_snapshot.pkl. Snapshots are large "
"(~1-1.5 GiB per batch for big sequences).",
)
def predict(
query_json: Path,
inference_ckpt_path: Path | None = None,
Expand All @@ -156,6 +165,7 @@ def predict(
use_msa_server: bool = True,
use_templates: bool = True,
output_dir: Path | None = None,
record_memory_snapshot: bool = False,
):
"""Perform inference on a set of queries defined in the query_json."""
_torch_gpu_setup()
Expand Down Expand Up @@ -185,6 +195,7 @@ def predict(
use_msa_server,
use_templates,
output_dir,
record_memory_snapshot,
)

# Load inference query set
Expand Down Expand Up @@ -225,12 +236,12 @@ def align_msa_server(
msa_computation_settings_yaml: Path | None = None,
):
"""Run MSA server alignment only with ColabFold MSA server.

Example command:
python run_openfold.py align-msa-server \
--query_json query_example.json \
--output_dir output/msa_server_test \

More settings can be specified using the `msa_computation_settings_yaml` flag
An example yaml file is provided in `examples/msa_server.yml`
"""
Expand Down
18 changes: 18 additions & 0 deletions openfold3/tests/test_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from openfold3 import setup_openfold
from openfold3.core.config import config_utils
from openfold3.core.data.framework.data_module import DataModuleConfig
from openfold3.core.utils.callbacks import MemorySnapshotCallback
from openfold3.entry_points.experiment_runner import (
InferenceExperimentRunner,
TrainingExperimentRunner,
Expand Down Expand Up @@ -442,6 +443,23 @@ def test_use_templates_cli(self, use_templates_cli_arg, tmp_path, dummy_ckpt_fil
)
assert expt_runner.use_templates == use_templates_cli_arg

@pytest.mark.parametrize("record_memory_snapshot_cli_arg", [True, False])
def test_record_memory_snapshot_cli(
self, record_memory_snapshot_cli_arg, dummy_ckpt_file
):
expt_config = InferenceExperimentConfig(inference_ckpt_path=dummy_ckpt_file)
expt_runner = InferenceExperimentRunner(
expt_config, record_memory_snapshot=record_memory_snapshot_cli_arg
)
assert (
expt_config.experiment_settings.record_memory_snapshot
== record_memory_snapshot_cli_arg
)
has_callback = any(
isinstance(cb, MemorySnapshotCallback) for cb in expt_runner.callbacks
)
assert has_callback == record_memory_snapshot_cli_arg

def test_seeding_from_num_seeds(self, dummy_ckpt_file):
expt_config = InferenceExperimentConfig(inference_ckpt_path=dummy_ckpt_file)
num_seeds = 7
Expand Down