diff --git a/.vscode/launch.json b/.vscode/launch.json index 92d553d..4701584 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -30,7 +30,7 @@ "-NIR", "-S1", "-treg", "rwa", - "-tregtrain", "rwa", + "-tregtrain", "rwa", "--seed", "1600", "-occmodel", "-wd", "0.0000005", diff --git a/model/popcorn.py b/model/popcorn.py index 51f0bfd..dda8541 100644 --- a/model/popcorn.py +++ b/model/popcorn.py @@ -157,8 +157,11 @@ def forward(self, inputs, train=False, padding=True, return_features=True, headin = torch.cat(middlefeatures, dim=1) - # forward the head, TODO: make this module sparse - out = self.head(headin)[:,0] + # forward the head + if sparse: + out = self.sparse_module_forward(headin, sparsity_mask, self.head, out_channels=2)[:,0] + else: + out = self.head(headin)[:,0] # Population map and total count if self.occupancymodel: @@ -166,15 +169,13 @@ def forward(self, inputs, train=False, padding=True, return_features=True, # activation function for the population map is a ReLU to avoid negative values scale = nn.functional.relu(out) - if "building_counts" in inputs.keys(): - - # save the scale + if sparse: + aux["scale"] = scale[sparsity_mask] + else: aux["scale"] = scale - # Get the population density map - popdensemap = scale * inputs["building_counts"][:,0] - else: - raise ValueError("building_counts not in inputs.keys(), but occupancy model is True") + # Get the population density map + popdensemap = scale * inputs["building_counts"][:,0] else: popdensemap = nn.functional.relu(out) aux["scale"] = None @@ -189,9 +190,43 @@ def forward(self, inputs, train=False, padding=True, return_features=True, popcount = popdensemap.sum((1,2)) return {"popcount": popcount, "popdensemap": popdensemap, - **aux, - } + **aux } + + def sparse_module_forward(self, inp: torch.Tensor, mask: torch.Tensor, + module: callable, out_channels=2) -> torch.Tensor: + """ + Description: + - Perform a forward pass with a module on a sparse input + Input: + - inp (torch.Tensor): input data + - mask (torch.Tensor): mask of the input data + - module (torch.nn.Module): module to apply + - out_channels (int): number of output channels + Output: + - out (torch.Tensor): output data + """ + # Validate input shape + if len(inp.shape) != 4: + raise ValueError("Input tensor must have shape (batch_size, channels, height, width)") + + # bring everything together + batch_size, channels, height, width = inp.shape + inp_flat = inp.permute(1,0,2,3).contiguous().view(channels, -1, 1) + + # flatten mask + mask_flat = mask.view(-1) + + # initialize the output + out_flat = torch.zeros((out_channels, batch_size*height*width,1), device=inp.device, dtype=inp.dtype) + # form together the output + out_flat[ :, mask_flat] = module(inp_flat[:, mask_flat]) + + # reshape the output + out = out_flat.view(out_channels, batch_size, height, width).permute(1,0,2,3) + + return out + def add_padding(self, data: torch.Tensor, force=True) -> torch.Tensor: """ @@ -285,7 +320,6 @@ def create_building_score(self, inputs: dict) -> torch.Tensor: score = self.revert_padding(score, (px1,px2,py1,py2)) return score - def get_sparsity_mask(self, inputs: torch.Tensor, sparse_unet=False) -> torch.Tensor: diff --git a/run_train.py b/run_train.py index 1496951..4be695d 100644 --- a/run_train.py +++ b/run_train.py @@ -114,6 +114,7 @@ def train(self): with tqdm(range(self.info["epoch"], self.args.num_epochs), leave=True) as tnr: tnr.set_postfix(training_loss=np.nan, validation_loss=np.nan, best_validation_loss=np.nan) + for _ in tnr: self.train_epoch(tnr) @@ -444,7 +445,7 @@ def get_dataloaders(self, args: argparse.Namespace) -> dict: def save_model(self, prefix=''): """ Input: - prefix: string to prepend to the filename + prefix: string to prepend to the filename """ torch.save({ 'model': self.model.state_dict(),