From 4a76e497846d994b18f70fe74fb707e3b382d1d2 Mon Sep 17 00:00:00 2001
From: Krishna Kumar
Date: Mon, 15 Jul 2024 09:26:32 -0600
Subject: [PATCH] Fix data fetch

---
 gns/train.py | 42 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/gns/train.py b/gns/train.py
index 519176e..9701a2e 100644
--- a/gns/train.py
+++ b/gns/train.py
@@ -1097,23 +1097,52 @@ def validation(simulator, example, n_features, cfg, rank, device_id):
     return loss
 
 def get_batch_for_material(train_dl, target_material_id, device_id):
+    """
+    Randomly search through the dataloader until a batch with the target material ID is found.
+    """
     dataset = train_dl.dataset
     indices = list(range(len(dataset)))
     random.shuffle(indices)
 
     for idx in indices:
         batch = dataset[idx]
-        _, _, material_property, _, _ = prepare_data([t.unsqueeze(0) for t in batch], device_id)
+        # batch is now (features, label)
+        features, label = batch
 
-        if material_property.numel() == 1:
-            batch_material_id = material_property.item()
+        # Unpack features
+        if len(features) == 4:  # If material property is present
+            positions, particle_type, material_property, n_particles_per_example = features
+        else:
+            positions, particle_type, n_particles_per_example = features
+            material_property = None
+
+        # Check material property
+        if material_property is not None:
+            if isinstance(material_property, np.ndarray) and material_property.size > 0:
+                batch_material_id = material_property[0]
+            else:
+                batch_material_id = material_property
         else:
-            batch_material_id = material_property[0].item()
+            # If material property is not present, we can't match it
+            continue
 
         if batch_material_id == target_material_id:
-            return [t.unsqueeze(0) for t in batch]
+            # Convert numpy arrays to tensors
+            positions = torch.from_numpy(positions).float().unsqueeze(0)
+            particle_type = torch.from_numpy(particle_type).long().unsqueeze(0)
+            if material_property is not None:
+                material_property = torch.from_numpy(np.array(material_property)).float().unsqueeze(0)
+            n_particles_per_example = torch.tensor([n_particles_per_example]).long()
+            label = torch.from_numpy(label).float().unsqueeze(0)
+
+            # Recreate the batch structure
+            features = (positions, particle_type, material_property, n_particles_per_example) if material_property is not None else (positions, particle_type, n_particles_per_example)
+            return (features, label)
 
-    return [t.unsqueeze(0) for t in dataset[random.choice(indices)]]
+    # If we've gone through the entire dataset without finding the target material,
+    # we'll just return a random batch (this is to handle edge cases)
+    random_idx = random.choice(indices)
+    return dataset[random_idx]
 
 
 def get_unique_material_ids(train_dl, device_id):
     unique_ids = set()
@@ -1204,7 +1233,6 @@ def train_reptile(rank, cfg, world_size, device, verbose, use_dist):
             for _ in range(inner_steps):
                 # Prepare data
                 position, particle_type, material_property, n_particles_per_example, labels = prepare_data(example, device_id)
-                n_particles_per_example = n_particles_per_example.to(device_id)
                 labels = labels.to(device_id)
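
---

Reviewer note, not part of the patch: a minimal sketch of how the revised get_batch_for_material could be exercised against a tiny in-memory dataset, since the patch changes the return contract from a list of unsqueezed tensors to a (features, label) tuple. The TinyMaterialDataset class, its array shapes, the material IDs, and the import path gns.train are assumptions for illustration only; the real dataset in gns/train.py may use different shapes and field names.

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from gns.train import get_batch_for_material  # assumed import path


class TinyMaterialDataset(Dataset):
    # Hypothetical stand-in: yields (features, label) samples of numpy arrays,
    # where features is the 4-tuple get_batch_for_material expects when a
    # material property is present.
    def __init__(self):
        self._samples = []
        for material_id in (3.0, 7.0):
            positions = np.random.rand(5, 6, 2).astype(np.float32)   # (particles, timesteps, dims) -- assumed shape
            particle_type = np.zeros(5, dtype=np.int64)
            material_property = np.array([material_id], dtype=np.float32)
            n_particles_per_example = 5
            label = np.random.rand(5, 2).astype(np.float32)           # next-step target -- assumed shape
            features = (positions, particle_type, material_property, n_particles_per_example)
            self._samples.append((features, label))

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, idx):
        return self._samples[idx]


train_dl = DataLoader(TinyMaterialDataset(), batch_size=1)

# Ask for material ID 7; the helper scans train_dl.dataset directly.
features, label = get_batch_for_material(train_dl, target_material_id=7.0, device_id=0)
positions, particle_type, material_property, n_particles_per_example = features
print(material_property)  # tensor([[7.]]) when the target material was found
print(label.shape)        # torch.Size([1, 5, 2])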