
Commit

Fix data fetch
kks32 committed Jul 15, 2024
1 parent c9c2b59 commit 4a76e49
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions gns/train.py
@@ -1097,23 +1097,52 @@ def validation(simulator, example, n_features, cfg, rank, device_id):
     return loss
 
 def get_batch_for_material(train_dl, target_material_id, device_id):
+    """
+    Randomly search through the dataloader until a batch with the target material ID is found.
+    """
     dataset = train_dl.dataset
     indices = list(range(len(dataset)))
     random.shuffle(indices)
 
     for idx in indices:
         batch = dataset[idx]
-        _, _, material_property, _, _ = prepare_data([t.unsqueeze(0) for t in batch], device_id)
+        # batch is now (features, label)
+        features, label = batch
 
-        if material_property.numel() == 1:
-            batch_material_id = material_property.item()
+        # Unpack features
+        if len(features) == 4:  # If material property is present
+            positions, particle_type, material_property, n_particles_per_example = features
+        else:
+            positions, particle_type, n_particles_per_example = features
+            material_property = None
+
+        # Check material property
+        if material_property is not None:
+            if isinstance(material_property, np.ndarray) and material_property.size > 0:
+                batch_material_id = material_property[0]
+            else:
+                batch_material_id = material_property
         else:
-            batch_material_id = material_property[0].item()
+            # If material property is not present, we can't match it
+            continue
 
         if batch_material_id == target_material_id:
-            return [t.unsqueeze(0) for t in batch]
+            # Convert numpy arrays to tensors
+            positions = torch.from_numpy(positions).float().unsqueeze(0)
+            particle_type = torch.from_numpy(particle_type).long().unsqueeze(0)
+            if material_property is not None:
+                material_property = torch.from_numpy(np.array(material_property)).float().unsqueeze(0)
+            n_particles_per_example = torch.tensor([n_particles_per_example]).long()
+            label = torch.from_numpy(label).float().unsqueeze(0)
+
+            # Recreate the batch structure
+            features = (positions, particle_type, material_property, n_particles_per_example) if material_property is not None else (positions, particle_type, n_particles_per_example)
+            return (features, label)
 
-    return [t.unsqueeze(0) for t in dataset[random.choice(indices)]]
+    # If we've gone through the entire dataset without finding the target material,
+    # we'll just return a random batch (this is to handle edge cases)
+    random_idx = random.choice(indices)
+    return dataset[random_idx]
 
 def get_unique_material_ids(train_dl, device_id):
     unique_ids = set()
@@ -1204,7 +1233,6 @@ def train_reptile(rank, cfg, world_size, device, verbose, use_dist):
         for _ in range(inner_steps):
             # Prepare data
             position, particle_type, material_property, n_particles_per_example, labels = prepare_data(example, device_id)
-
             n_particles_per_example = n_particles_per_example.to(device_id)
             labels = labels.to(device_id)
 
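
For orientation, a minimal usage sketch (not part of this commit) of how the reworked return value of get_batch_for_material could be consumed; train_dl, device_id and target_material_id are assumed to come from the surrounding training code, and the variable names are illustrative only:

# Hypothetical sketch, not from the commit: fetch one batch for a target material
# and unpack the (features, label) structure that get_batch_for_material now returns.
features, label = get_batch_for_material(train_dl, target_material_id, device_id)
if len(features) == 4:  # material property stored with the trajectory
    position, particle_type, material_property, n_particles_per_example = features
else:
    position, particle_type, n_particles_per_example = features
    material_property = None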
