Skip to content

Commit

Permalink
Prepare data function
Browse files Browse the repository at this point in the history
  • Loading branch information
kks32 committed Jun 29, 2024
1 parent 27919d1 commit 985a68a
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,25 @@ def setup_tensorboard(cfg, metadata):
return writer


def prepare_data(example, device_id):
"""Prepare data for training or validation."""
position = example[0][0].to(device_id)
particle_type = example[0][1].to(device_id)

if len(example[0]) == 4: # if data loader includes material_property
material_property = example[0][2].to(device_id)
n_particles_per_example = example[0][3].to(device_id)
elif len(example[0]) == 3:
material_property = None
n_particles_per_example = example[0][2].to(device_id)
else:
raise ValueError("Unexpected number of elements in the data loader")

labels = example[1].to(device_id)

return position, particle_type, material_property, n_particles_per_example, labels


def train(rank, cfg, world_size, device, verbose):
"""Train the model.
Expand Down Expand Up @@ -484,9 +503,7 @@ def train(rank, cfg, world_size, device, verbose):
writer = setup_tensorboard(cfg, metadata) if verbose else None

try:
num_epochs = max(
1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl)
) # Calculate total epochs
num_epochs = max(1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl))
print(f"Total epochs = {num_epochs}")
for epoch in tqdm(range(epoch, num_epochs), desc="Training", unit="epoch"):
if device == torch.device("cuda"):
Expand All @@ -499,16 +516,14 @@ def train(rank, cfg, world_size, device, verbose):
with tqdm(total=len(train_dl), desc=f"Epoch {epoch}", unit="batch") as pbar:
for example in train_dl:
steps_per_epoch += 1
position = example[0][0].to(device_id)
particle_type = example[0][1].to(device_id)
if n_features == 3:
material_property = example[0][2].to(device_id)
n_particles_per_example = example[0][3].to(device_id)
elif n_features == 2:
n_particles_per_example = example[0][2].to(device_id)
else:
raise NotImplementedError
labels = example[1].to(device_id)
# Prepare data
(
position,
particle_type,
material_property,
n_particles_per_example,
labels,
) = prepare_data(example, device_id)

n_particles_per_example = n_particles_per_example.to(device_id)
labels = labels.to(device_id)
Expand Down Expand Up @@ -739,16 +754,10 @@ def _get_simulator(


def validation(simulator, example, n_features, cfg, rank, device_id):
position = example[0][0].to(device_id)
particle_type = example[0][1].to(device_id)
if n_features == 3: # if dl includes material_property
material_property = example[0][2].to(device_id)
n_particles_per_example = example[0][3].to(device_id)
elif n_features == 2:
n_particles_per_example = example[0][2].to(device_id)
else:
raise NotImplementedError
labels = example[1].to(device_id)

position, particle_type, material_property, n_particles_per_example, labels = (
prepare_data(example, device_id)
)

# Sample the noise to add to the inputs.
sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(
Expand Down

0 comments on commit 985a68a

Please sign in to comment.