Skip to content

Commit

Permalink
GNS train get batch
Browse files — browse the repository at this point in the history
  • Loading branch information
kks32 committed Jul 15, 2024
1 parent 4a76e49 commit edc9638
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,38 +1106,20 @@ def get_batch_for_material(train_dl, target_material_id, device_id):

for idx in indices:
batch = dataset[idx]
# batch is now (features, label)
features, label = batch
features, _ = batch

# Unpack features
if len(features) == 4: # If material property is present
positions, particle_type, material_property, n_particles_per_example = features
_, _, material_property, _ = features
else:
positions, particle_type, n_particles_per_example = features
material_property = None
continue # Skip if material property is not present

# Check material property
if material_property is not None:
if isinstance(material_property, np.ndarray) and material_property.size > 0:
batch_material_id = material_property[0]
else:
batch_material_id = material_property
if isinstance(material_property, np.ndarray) and material_property.size > 0:
batch_material_id = material_property[0]
else:
# If material property is not present, we can't match it
continue
batch_material_id = material_property

if batch_material_id == target_material_id:
# Convert numpy arrays to tensors
positions = torch.from_numpy(positions).float().unsqueeze(0)
particle_type = torch.from_numpy(particle_type).long().unsqueeze(0)
if material_property is not None:
material_property = torch.from_numpy(np.array(material_property)).float().unsqueeze(0)
n_particles_per_example = torch.tensor([n_particles_per_example]).long()
label = torch.from_numpy(label).float().unsqueeze(0)

# Recreate the batch structure
features = (positions, particle_type, material_property, n_particles_per_example) if material_property is not None else (positions, particle_type, n_particles_per_example)
return (features, label)
return batch

# If we've gone through the entire dataset without finding the target material,
# we'll just return a random batch (this is to handle edge cases)
Expand Down

0 comments on commit edc9638

Please sign in to comment.