Commit 947fd58
committed · model added
1 parent a4f2425 · commit 947fd58

20 files changed: +2122 −1 lines

v2xvit/models/__init__.py

Whitespace-only changes.
New file, +122 lines. The file path is not shown in this capture; it defines PointPillarTransformer.

@@ -0,0 +1,122 @@
import torch
import torch.nn as nn

from v2xvit.models.sub_modules.pillar_vfe import PillarVFE
from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter
from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone
from v2xvit.models.sub_modules.fuse_utils import regroup
from v2xvit.models.sub_modules.downsample_conv import DownsampleConv
from v2xvit.models.sub_modules.naive_compress import NaiveCompressor
from v2xvit.models.sub_modules.v2xvit_basic import V2XTransformer


class PointPillarTransformer(nn.Module):
    def __init__(self, args):
        super(PointPillarTransformer, self).__init__()

        self.max_cav = args['max_cav']
        # Pillar VFE
        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
                                    num_point_features=4,
                                    voxel_size=args['voxel_size'],
                                    point_cloud_range=args['lidar_range'])
        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
        self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)
        # used to downsample the feature map for efficient computation
        self.shrink_flag = False
        if 'shrink_header' in args:
            self.shrink_flag = True
            self.shrink_conv = DownsampleConv(args['shrink_header'])
        self.compression = False

        if args['compression'] > 0:
            self.compression = True
            self.naive_compressor = NaiveCompressor(256, args['compression'])

        self.fusion_net = V2XTransformer(args['transformer'])

        self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
                                  kernel_size=1)
        self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
                                  kernel_size=1)

        if args['backbone_fix']:
            self.backbone_fix()

    def backbone_fix(self):
        """
        Freeze the backbone parameters when fine-tuning on time delay.
        """
        for p in self.pillar_vfe.parameters():
            p.requires_grad = False

        for p in self.scatter.parameters():
            p.requires_grad = False

        for p in self.backbone.parameters():
            p.requires_grad = False

        if self.compression:
            for p in self.naive_compressor.parameters():
                p.requires_grad = False
        if self.shrink_flag:
            for p in self.shrink_conv.parameters():
                p.requires_grad = False

        for p in self.cls_head.parameters():
            p.requires_grad = False
        for p in self.reg_head.parameters():
            p.requires_grad = False

    def forward(self, data_dict):
        voxel_features = data_dict['processed_lidar']['voxel_features']
        voxel_coords = data_dict['processed_lidar']['voxel_coords']
        voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
        record_len = data_dict['record_len']
        spatial_correction_matrix = data_dict['spatial_correction_matrix']

        # B, max_cav, 3 (dt, dv, infra), 1, 1
        prior_encoding = \
            data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1)

        batch_dict = {'voxel_features': voxel_features,
                      'voxel_coords': voxel_coords,
                      'voxel_num_points': voxel_num_points,
                      'record_len': record_len}
        # n, 4 -> n, c
        batch_dict = self.pillar_vfe(batch_dict)
        # n, c -> N, C, H, W
        batch_dict = self.scatter(batch_dict)
        batch_dict = self.backbone(batch_dict)

        spatial_features_2d = batch_dict['spatial_features_2d']
        # downsample feature to reduce memory
        if self.shrink_flag:
            spatial_features_2d = self.shrink_conv(spatial_features_2d)
        # compressor
        if self.compression:
            spatial_features_2d = self.naive_compressor(spatial_features_2d)
        # N, C, H, W -> B, L, C, H, W
        regroup_feature, mask = regroup(spatial_features_2d,
                                        record_len,
                                        self.max_cav)
        # prior encoding added
        prior_encoding = prior_encoding.repeat(1, 1, 1,
                                               regroup_feature.shape[3],
                                               regroup_feature.shape[4])
        regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2)

        # b l c h w -> b l h w c
        regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2)
        # transformer fusion
        fused_feature = self.fusion_net(regroup_feature, mask,
                                        spatial_correction_matrix)
        # b h w c -> b c h w
        fused_feature = fused_feature.permute(0, 3, 1, 2)

        psm = self.cls_head(fused_feature)
        rm = self.reg_head(fused_feature)

        output_dict = {'psm': psm,
                       'rm': rm}

        return output_dict
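
The forward pass above broadcasts the per-agent prior encoding (time delay, velocity, infrastructure flag) over the spatial grid and appends it as three extra channels before fusion. The snippet below is a small, self-contained sketch of just that step with dummy tensors; the shapes (B=2, L=5, C=256, H=48, W=176) are arbitrary assumptions for illustration, not values from this commit.

# Sketch of the prior-encoding broadcast/concat step, using dummy shapes.
import torch

B, L, C, H, W = 2, 5, 256, 48, 176            # batch, max_cav, channels, height, width (assumed)
regroup_feature = torch.randn(B, L, C, H, W)  # what regroup() returns: B, L, C, H, W
prior_encoding = torch.randn(B, L, 3)         # (dt, dv, infra) per agent

# B, L, 3 -> B, L, 3, 1, 1 (done before forward() repeats it)
prior_encoding = prior_encoding.unsqueeze(-1).unsqueeze(-1)
# broadcast the three scalars over the spatial grid: B, L, 3, H, W
prior_encoding = prior_encoding.repeat(1, 1, 1, H, W)
# append them as three extra channels: B, L, C+3, H, W
regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2)
# channels-last layout expected by the fusion transformer: B, L, H, W, C+3
regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2)
print(regroup_feature.shape)  # torch.Size([2, 5, 48, 176, 259])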

v2xvit/models/sub_modules/__init__.py

Whitespace-only changes.
New file, +122 lines. The file path is not shown in this capture; it defines BaseBEVBackbone.

@@ -0,0 +1,122 @@
import numpy as np
import torch
import torch.nn as nn


class BaseBEVBackbone(nn.Module):
    def __init__(self, model_cfg, input_channels):
        super().__init__()
        self.model_cfg = model_cfg

        if 'layer_nums' in self.model_cfg:

            assert len(self.model_cfg['layer_nums']) == \
                   len(self.model_cfg['layer_strides']) == \
                   len(self.model_cfg['num_filters'])

            layer_nums = self.model_cfg['layer_nums']
            layer_strides = self.model_cfg['layer_strides']
            num_filters = self.model_cfg['num_filters']
        else:
            layer_nums = layer_strides = num_filters = []

        if 'upsample_strides' in self.model_cfg:
            assert len(self.model_cfg['upsample_strides']) \
                   == len(self.model_cfg['num_upsample_filter'])

            num_upsample_filters = self.model_cfg['num_upsample_filter']
            upsample_strides = self.model_cfg['upsample_strides']

        else:
            upsample_strides = num_upsample_filters = []

        num_levels = len(layer_nums)
        c_in_list = [input_channels, *num_filters[:-1]]

        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()

        for idx in range(num_levels):
            cur_layers = [
                nn.ZeroPad2d(1),
                nn.Conv2d(
                    c_in_list[idx], num_filters[idx], kernel_size=3,
                    stride=layer_strides[idx], padding=0, bias=False
                ),
                nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                nn.ReLU()
            ]
            for k in range(layer_nums[idx]):
                cur_layers.extend([
                    nn.Conv2d(num_filters[idx], num_filters[idx],
                              kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ])

            self.blocks.append(nn.Sequential(*cur_layers))
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx],
                                       eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # a fractional upsample stride means downsampling with a strided conv
                    stride = int(np.round(1 / stride))
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3,
                                       momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        if len(upsample_strides) > num_levels:
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1],
                                   stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    def forward(self, data_dict):
        spatial_features = data_dict['spatial_features']

        ups = []
        ret_dict = {}
        x = spatial_features

        for i in range(len(self.blocks)):
            x = self.blocks[i](x)

            stride = int(spatial_features.shape[2] / x.shape[2])
            ret_dict['spatial_features_%dx' % stride] = x

            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x
        return data_dict
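
A rough usage sketch for the backbone above: the config keys match what __init__ reads, but the values are assumed, PointPillars-style settings rather than this repo's actual YAML, and the input tensor shape is made up. Run it in the same module so BaseBEVBackbone is in scope.

import torch

cfg = {
    'layer_nums': [3, 5, 8],                # conv layers per level
    'layer_strides': [2, 2, 2],             # downsampling stride per level
    'num_filters': [64, 128, 256],          # channels per level
    'upsample_strides': [1, 2, 4],          # bring every level back to the same resolution
    'num_upsample_filter': [128, 128, 128],
}
backbone = BaseBEVBackbone(cfg, input_channels=64)

# dummy pseudo-image from PointPillarScatter: N, C, H, W (shape assumed)
data_dict = {'spatial_features': torch.randn(2, 64, 200, 176)}
data_dict = backbone(data_dict)
# the three upsampled levels are concatenated: 128 * 3 = 384 channels
print(data_dict['spatial_features_2d'].shape)  # torch.Size([2, 384, 100, 88])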
New file, +124 lines. The file path is not shown in this capture; it defines the transformer building blocks (PreNorm, FeedForward, CavAttention, BaseEncoder, BaseTransformer).

@@ -0,0 +1,124 @@
import torch
from torch import nn

from einops import rearrange


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class CavAttention(nn.Module):
    """
    Vanilla CAV attention.
    """
    def __init__(self, dim, heads, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask, prior_encoding=None):
        # prior_encoding is accepted for interface compatibility but unused here
        # x: (B, L, H, W, C) -> (B, H, W, L, C)
        # mask: (B, L)
        x = x.permute(0, 2, 3, 1, 4)
        # mask: (B, 1, H, W, L, 1)
        mask = mask.unsqueeze(1)

        # qkv: [(B, H, W, L, C_inner) * 3]
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # q: (B, M, H, W, L, C)
        q, k, v = map(lambda t: rearrange(t, 'b h w l (m c) -> b m h w l c',
                                          m=self.heads), qkv)

        # attention, (B, M, H, W, L, L)
        att_map = torch.einsum('b m h w i c, b m h w j c -> b m h w i j',
                               q, k) * self.scale
        # add mask
        att_map = att_map.masked_fill(mask == 0, -float('inf'))
        # softmax
        att_map = self.attend(att_map)

        # out: (B, M, H, W, L, C_head)
        out = torch.einsum('b m h w i j, b m h w j c -> b m h w i c', att_map,
                           v)
        out = rearrange(out, 'b m h w l c -> b h w l (m c)',
                        m=self.heads)
        out = self.to_out(out)
        # (B, L, H, W, C)
        out = out.permute(0, 3, 1, 2, 4)
        return out


class BaseEncoder(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, CavAttention(dim,
                                          heads=heads,
                                          dim_head=dim_head,
                                          dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x, mask):
        for attn, ff in self.layers:
            x = attn(x, mask=mask) + x
            x = ff(x) + x
        return x


class BaseTransformer(nn.Module):
    def __init__(self, args):
        super().__init__()

        dim = args['dim']
        depth = args['depth']
        heads = args['heads']
        dim_head = args['dim_head']
        mlp_dim = args['mlp_dim']
        dropout = args['dropout']
        max_cav = args['max_cav']

        self.encoder = BaseEncoder(dim, depth, heads, dim_head, mlp_dim,
                                   dropout)

    def forward(self, x, mask):
        # B, L, H, W, C
        output = self.encoder(x, mask)
        # B, H, W, C: keep the first (ego) agent's slice
        output = output[:, 0]

        return output
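
A minimal smoke test for the vanilla fusion path above (BaseTransformer over CavAttention), not part of the commit. The hyper-parameters and tensor shapes are assumptions chosen for illustration; the mask layout follows the (B, H, W, L, 1) shape implied by the comments in CavAttention.forward, with 1 marking a present agent. Run it in the same module so the classes are in scope.

import torch

args = {'dim': 256, 'depth': 2, 'heads': 8, 'dim_head': 32,
        'mlp_dim': 256, 'dropout': 0.1, 'max_cav': 5}
model = BaseTransformer(args)

B, L, H, W, C = 1, 5, 32, 32, 256
x = torch.randn(B, L, H, W, C)    # per-agent BEV features, channels last
mask = torch.ones(B, H, W, L, 1)  # 1 = agent present; zeros would mark padded agents
                                  # (unsqueezed to (B, 1, H, W, L, 1) inside CavAttention)

out = model(x, mask)
print(out.shape)  # torch.Size([1, 32, 32, 256]): the ego agent's fused feature map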
