Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions openfold3/core/data/framework/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ class DataModuleConfig(BaseModel):
# Custom worker init function with manual data seed
# (defined at the top level so that it can be pickled regardless of the
# forking strategy and platform)
def _worker_init_function_with_data_seed(
self,
worker_id: int, rank: int | None = None
self, worker_id: int, rank: int | None = None
) -> None:
"""Modified default Lightning worker_init_fn with manual data seed.

Expand All @@ -182,19 +181,13 @@ def _worker_init_function_with_data_seed(
process_seed = self.data_seed
# back out the base seed so we can use all the bits
base_seed = process_seed - worker_id
seed_sequence = _generate_seed_sequence(
base_seed, worker_id, global_rank, count=4
)
seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4)
torch.manual_seed(seed_sequence[0]) # torch takes a 64-bit seed
random.seed(
(seed_sequence[1] << 32) | seed_sequence[2]
) # combine two 64-bit seeds
random.seed((seed_sequence[1] << 32) | seed_sequence[2]) # combine two 64-bit seeds
if _NUMPY_AVAILABLE:
import numpy as np

np.random.seed(
seed_sequence[3] & 0xFFFFFFFF
) # numpy takes 32-bit seed only
np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only


class DataModule(pl.LightningDataModule):
Expand Down Expand Up @@ -226,7 +219,10 @@ def _initialize_next_dataset_indices(self):

def setup(self, stage=None):
from functools import partial
self.worker_init_function_with_data_seed = partial(_worker_init_function_with_data_seed, self)

self.worker_init_function_with_data_seed = partial(
_worker_init_function_with_data_seed, self
)
self.generator = torch.Generator(device="cpu").manual_seed(self.data_seed)

self.datasets_by_mode = {k: [] for k in DatasetMode}
Expand Down Expand Up @@ -447,7 +443,9 @@ def generate_dataloader(self, mode: DatasetMode):
generator=self.generator,
worker_init_fn=self.worker_init_function_with_data_seed,
# https://github.com/pytorch/pytorch/issues/87688
multiprocessing_context = "fork" if torch.backends.mps.is_available() and num_workers else None
multiprocessing_context="fork"
if torch.backends.mps.is_available() and num_workers
else None,
)

def train_dataloader(self) -> DataLoader:
Expand Down
45 changes: 1 addition & 44 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
# | H100 cuda13 | to test | to test | to test | to test |
# | B300 cuda12 | works | works | to test | to test |
# | B300 cuda13 | works | works | to test | to test |
# | GB10 cuda12 | n/a | n/a | works | works |
# | GB10 cuda13 | n/a | n/a | works | works |
# | CPU | works | n/a | to test | n/a |
# +---------------+------------+------------+------------+------------+
# * works = "works for me"
Expand Down Expand Up @@ -368,7 +370,7 @@ torch = { version = "*", index = "https://download.pytorch.org/whl/cu130" }
# 2.1) No aarch wheel for cuequivariance-ops-torch-cu13 <0.8
#

[feature.cuequivariance-cuda12.pypi-dependencies]
[feature.cuequivariance-cuda12.target.linux-64.pypi-dependencies]
cuequivariance = "<0.8"
cuequivariance-torch = "<0.8"
cuequivariance-ops-torch-cu12 = "<0.8"
Expand Down