Commit
Merge pull request #342 from netease-youdao/develop_for_v1.3.1
Develop for v1.3.1
xixihahaliu authored May 17, 2024
2 parents 5f2e9dc + f795b40 commit e0181b7
Showing 96 changed files with 14,863 additions and 1 deletion.
@@ -0,0 +1,155 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn
from .utils import _tranpose_and_gather_feat, _get_wh_feat, _get_4ps_feat, _normalized_ps
import torch.nn.functional as F

import json
import cv2
import os
from .transformer import Transformer
import math
import time
import random
import imgaug.augmenters as iaa
import copy

class Stacker(nn.Module):
def __init__(self, input_size, hidden_size, output_size, layers, heads=8, dropout=0.1):
super(Stacker, self).__init__()
self.logi_encoder = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(inplace=True) #newly added
)
self.tsfm = Transformer(2 * hidden_size, hidden_size, output_size, layers, heads, dropout)
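        # The encoded logical coordinates are concatenated with the cell features
        # (hence the 2 * hidden_size input) and refined by a second Transformer.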

def forward(self, outputs, logi, mask = None, require_att = False):
logi_embeddings = self.logi_encoder(logi)

cat_embeddings = torch.cat((logi_embeddings, outputs), dim=2)

if mask is None:
if require_att:
stacked_axis, att = self.tsfm(cat_embeddings)
else:
stacked_axis = self.tsfm(cat_embeddings)
else:
stacked_axis = self.tsfm(cat_embeddings, mask=mask)

if require_att:
return stacked_axis, att
else:
return stacked_axis

class Processor(nn.Module):
def __init__(self, opt):
super(Processor, self).__init__()

if opt.wiz_stacking:
self.stacker = Stacker(opt.output_size, opt.hidden_size, opt.output_size, opt.stacking_layers)

        # input_size, hidden_size, output_size, layers, heads, dropout
self.tsfm_axis = Transformer(opt.input_size, opt.hidden_size, opt.output_size, opt.tsfm_layers, opt.num_heads, opt.att_dropout) #original version
self.x_position_embeddings = nn.Embedding(opt.max_fmp_size, opt.hidden_size)
self.y_position_embeddings = nn.Embedding(opt.max_fmp_size, opt.hidden_size)

self.opt = opt

    def forward(self, outputs, dets = None, batch = None, cc_match = None): # used for both training (batch given) and inference (batch is None)
        # 'outputs' holds the features of the cells
        # mask = None
        # att = None

'''
Constructing Features:
'''
if batch is None:
            # Inference mode: the four corner features are gathered during
            # bounding-box decoding for simplicity (see ctdet_4ps_decode() in ./src/lib/model/decode.py).

vis_feat = outputs
if dets is None:
feat = vis_feat

else:
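                # add positional embeddings of the decoded box extent
                # (left x, upper y, right x, lower y) onto the visual features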
left_pe = self.x_position_embeddings(dets[:, :, 0])
upper_pe = self.y_position_embeddings(dets[:, :, 1])
right_pe = self.x_position_embeddings(dets[:, :, 2])
lower_pe = self.y_position_embeddings(dets[:, :, 5])
feat = vis_feat + left_pe + upper_pe + right_pe + lower_pe

            # TODO: move the processing here and unify the feature-construction code for training and inference.

else:
#Training Mode
ind = batch['hm_ind']
            mask = batch['hm_mask'] # during training, the attention mask is applied
output = outputs[-1]
pred = output['ax']
ct_feat = _tranpose_and_gather_feat(pred, ind)
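            # Feature construction depends on the option flags:
            #   wiz_2dpe:    center + corner features plus positional embeddings of the GT box
            #   wiz_4ps:     center + corner features
            #   wiz_vanilla: center features only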

if self.opt.wiz_2dpe:
cr_feat = _get_4ps_feat(batch['cc_match'], output)
cr_feat = cr_feat.sum(axis = 3)
vis_feat = ct_feat + cr_feat

ps = _get_wh_feat(ind, batch, 'gt')
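                # normalize the ground-truth corner points to the feature-map coordinate range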
ps = _normalized_ps(ps, self.opt.max_fmp_size)

left_pe = self.x_position_embeddings(ps[:, :, 0])
upper_pe = self.y_position_embeddings(ps[:, :, 1])
right_pe = self.x_position_embeddings(ps[:, :, 2])
lower_pe = self.y_position_embeddings(ps[:, :, 5])

feat = vis_feat + left_pe + upper_pe + right_pe + lower_pe

elif self.opt.wiz_4ps:
cr_feat = _get_4ps_feat(batch['cc_match'], output)
cr_feat = cr_feat.sum(axis = 3)
feat = ct_feat + cr_feat

elif self.opt.wiz_vanilla:
feat = ct_feat

'''
Put Features into TSFM:
'''

if batch is None:
#Inference Mode
logic_axis = self.tsfm_axis(feat)
if self.opt.wiz_stacking:
stacked_axis = self.stacker(feat, logic_axis)
else:
#Training Mode
logic_axis = self.tsfm_axis(feat, mask = mask)
if self.opt.wiz_stacking:
stacked_axis = self.stacker(feat, logic_axis, mask = mask)

if self.opt.wiz_stacking:
return logic_axis, stacked_axis
else:
return logic_axis

def load_processor(model, model_path, optimizer=None, resume=False, lr=None, lr_step=None):
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict)
return model

def _judge(box):
    countx = len(set([box[0], box[2], box[4], box[6]]))
    county = len(set([box[1], box[3], box[5], box[7]]))
    if countx < 2 or county < 2:
        return False

    return True
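
For orientation, a minimal sketch of how the processor above might be driven at inference time. The Opt values, the checkpoint path, and the tensor shapes are illustrative assumptions, not part of this commit; Processor and load_processor are the definitions from the file above.

import torch

# Hypothetical option object exposing the attributes Processor reads; the
# concrete values are assumptions for this sketch.
class Opt:
    input_size = 256
    hidden_size = 256       # kept equal to input_size so positional embeddings can be added to the features
    output_size = 4         # assumed number of logical-location values per cell
    tsfm_layers = 3
    stacking_layers = 3
    num_heads = 8
    att_dropout = 0.1
    max_fmp_size = 256
    wiz_stacking = True
    wiz_2dpe = True
    wiz_4ps = False
    wiz_vanilla = False

processor = Processor(Opt())
processor = load_processor(processor, 'processor_best.pth')   # illustrative checkpoint path
processor.eval()

with torch.no_grad():
    cell_feats = torch.randn(1, 10, Opt.input_size)            # (batch, cells, feature dim)
    dets = torch.randint(0, Opt.max_fmp_size, (1, 10, 8))      # decoded 4-point boxes, integer coordinates
    logic_axis, stacked_axis = processor(cell_feats, dets=dets)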

@@ -0,0 +1,128 @@
import torch
from torch.nn.modules import Module
from torch.nn.parallel.scatter_gather import gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


from .scatter_gather import scatter_kwargs

class _DataParallel(Module):
r"""Implements data parallelism at the module level.
This container parallelizes the application of the given module by
splitting the input across the specified devices by chunking in the batch
dimension. In the forward pass, the module is replicated on each device,
and each replica handles a portion of the input. During the backwards
pass, gradients from each replica are summed into the original module.
The batch size should be larger than the number of GPUs used. It should
also be an integer multiple of the number of GPUs so that each chunk is the
same size (so that each GPU processes the same number of samples).
See also: :ref:`cuda-nn-dataparallel-instead`
Arbitrary positional and keyword inputs are allowed to be passed into
DataParallel EXCEPT Tensors. All variables will be scattered on dim
specified (default 0). Primitive types will be broadcasted, but all
other types will be a shallow copy and can be corrupted if written to in
the model's forward pass.
Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
output_device: device location of output (default: device_ids[0])
Example::
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)
"""

# TODO: update notes/cuda.rst when this class handles 8+ GPUs well

def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
super(_DataParallel, self).__init__()

if not torch.cuda.is_available():
self.module = module
self.device_ids = []
return

if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
self.dim = dim
self.module = module
self.device_ids = device_ids
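        # chunk_sizes, if given, allows unequal per-GPU batch splits; this is the
        # main behavioural difference from torch.nn.DataParallel.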
self.chunk_sizes = chunk_sizes
self.output_device = output_device
if len(self.device_ids) == 1:
self.module.cuda(device_ids[0])

def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)

def replicate(self, module, device_ids):
return replicate(module, device_ids)

def scatter(self, inputs, kwargs, device_ids, chunk_sizes):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes)

def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

def gather(self, outputs, output_device):
return gather(outputs, output_device, dim=self.dim)


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
This is the functional version of the DataParallel module.
Args:
module: the module to evaluate in parallel
inputs: inputs to the module
device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
(default: device_ids[0])
Returns:
a Variable containing the result of module(input) located on
output_device
"""
if not isinstance(inputs, tuple):
inputs = (inputs,)

if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))

if output_device is None:
output_device = device_ids[0]

inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
if len(device_ids) == 1:
return module(*inputs[0], **module_kwargs[0])
used_device_ids = device_ids[:len(inputs)]
replicas = replicate(module, used_device_ids)
outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
return gather(outputs, output_device, dim)

def DataParallel(module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
if chunk_sizes is None:
return torch.nn.DataParallel(module, device_ids, output_device, dim)
standard_size = True
for i in range(1, len(chunk_sizes)):
if chunk_sizes[i] != chunk_sizes[0]:
standard_size = False
if standard_size:
return torch.nn.DataParallel(module, device_ids, output_device, dim)
return _DataParallel(module, device_ids, output_device, dim, chunk_sizes)
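
A brief sketch, assuming a host with at least three GPUs, of how the chunk-size-aware DataParallel factory above might be used; the model and the split below are illustrative.

import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()

# An uneven split of a batch of 48 across three GPUs: because the chunk sizes
# differ, the factory returns the custom _DataParallel; with equal (or no)
# chunk sizes it falls back to torch.nn.DataParallel.
wrapper = DataParallel(model, device_ids=[0, 1, 2], chunk_sizes=[24, 12, 12])

x = torch.randn(48, 128).cuda()
out = wrapper(x)   # shape (48, 10), gathered on device_ids[0]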