Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse Head implementation #4

Merged
merged 7 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"-NIR",
"-S1",
"-treg", "rwa",
"-tregtrain", "rwa",
"-tregtrain", "rwa",
"--seed", "1600",
"-occmodel",
"-wd", "0.0000005",
Expand Down
58 changes: 46 additions & 12 deletions model/popcorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,25 @@ def forward(self, inputs, train=False, padding=True, return_features=True,

headin = torch.cat(middlefeatures, dim=1)

# forward the head, TODO: make this module sparse
out = self.head(headin)[:,0]
# forward the head
if sparse:
out = self.sparse_module_forward(headin, sparsity_mask, self.head, out_channels=2)[:,0]
else:
out = self.head(headin)[:,0]

# Population map and total count
if self.occupancymodel:

# activation function for the population map is a ReLU to avoid negative values
scale = nn.functional.relu(out)

if "building_counts" in inputs.keys():

# save the scale
if sparse:
aux["scale"] = scale[sparsity_mask]
else:
aux["scale"] = scale

# Get the population density map
popdensemap = scale * inputs["building_counts"][:,0]
else:
raise ValueError("building_counts not in inputs.keys(), but occupancy model is True")
# Get the population density map
popdensemap = scale * inputs["building_counts"][:,0]
else:
popdensemap = nn.functional.relu(out)
aux["scale"] = None
Expand All @@ -189,9 +190,43 @@ def forward(self, inputs, train=False, padding=True, return_features=True,
popcount = popdensemap.sum((1,2))

return {"popcount": popcount, "popdensemap": popdensemap,
**aux,
}
**aux }

def sparse_module_forward(self, inp: torch.Tensor, mask: torch.Tensor,
                module: callable, out_channels=2) -> torch.Tensor:
    """
    Description:
        - Apply `module` only to the pixels of `inp` selected by `mask`,
          writing zeros everywhere else.
    Input:
        - inp (torch.Tensor): dense input of shape (batch, channels, height, width)
        - mask (torch.Tensor): per-pixel selection mask covering batch*height*width
          entries (assumed boolean — TODO confirm; an integer mask would index
          rather than select)
        - module (torch.nn.Module): module applied to the selected columns
        - out_channels (int): number of channels produced by `module`
    Output:
        - out (torch.Tensor): dense output of shape (batch, out_channels, height, width)
    """
    # Only 4-D (B, C, H, W) inputs are supported.
    if inp.dim() != 4:
        raise ValueError("Input tensor must have shape (batch_size, channels, height, width)")

    num_batches, num_channels, height, width = inp.shape
    num_pixels = num_batches * height * width

    # Collapse batch and spatial dims so every pixel becomes one column:
    # (C, B*H*W, 1), channels leading to match the module's expected layout.
    flat_in = inp.permute(1, 0, 2, 3).contiguous().reshape(num_channels, num_pixels, 1)

    # One mask entry per flattened pixel.
    flat_mask = mask.reshape(-1)

    # Dense zero-initialised output on the same device/dtype as the input;
    # only the masked columns are computed and written.
    flat_out = inp.new_zeros((out_channels, num_pixels, 1))
    flat_out[:, flat_mask] = module(flat_in[:, flat_mask])

    # Restore the (B, out_channels, H, W) layout.
    return flat_out.reshape(out_channels, num_batches, height, width).permute(1, 0, 2, 3)


def add_padding(self, data: torch.Tensor, force=True) -> torch.Tensor:
"""
Expand Down Expand Up @@ -285,7 +320,6 @@ def create_building_score(self, inputs: dict) -> torch.Tensor:
score = self.revert_padding(score, (px1,px2,py1,py2))

return score



def get_sparsity_mask(self, inputs: torch.Tensor, sparse_unet=False) -> torch.Tensor:
Expand Down
3 changes: 2 additions & 1 deletion run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def train(self):

with tqdm(range(self.info["epoch"], self.args.num_epochs), leave=True) as tnr:
tnr.set_postfix(training_loss=np.nan, validation_loss=np.nan, best_validation_loss=np.nan)

for _ in tnr:

self.train_epoch(tnr)
Expand Down Expand Up @@ -444,7 +445,7 @@ def get_dataloaders(self, args: argparse.Namespace) -> dict:
def save_model(self, prefix=''):
"""
Input:
prefix: string to prepend to the filename
prefix: string to prepend to the filename
"""
torch.save({
'model': self.model.state_dict(),
Expand Down