Skip to content

Commit 63037ea

Browse files
committed
add upsample layer
1 parent 70a19d2 commit 63037ea

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

manotorch/upsamplelayer.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
import torch
3+
from torch.nn import Module
4+
5+
6+
class UpSampleLayer(Module):
7+
def __init__(self):
8+
super().__init__()
9+
10+
@staticmethod
11+
def calculate_faces(faces, vn):
12+
edges = {}
13+
new_faces = []
14+
15+
def get_edge_id(e):
16+
if e not in edges:
17+
edges[e] = len(edges)
18+
return edges[e]
19+
20+
for f in faces:
21+
a, b, c = f[0], f[1], f[2]
22+
e1, e2, e3 = tuple(sorted([a, b])), tuple(sorted([b, c])), tuple(sorted([c, a]))
23+
x = get_edge_id(e1) + vn
24+
y = get_edge_id(e2) + vn
25+
z = get_edge_id(e3) + vn
26+
new_faces.append(np.array([x, y, z]))
27+
new_faces.append(np.array([a, x, z]))
28+
new_faces.append(np.array([b, y, x]))
29+
new_faces.append(np.array([c, z, y]))
30+
31+
new_faces = np.vstack(new_faces)
32+
new_vertices_idx = np.vstack([np.array(list(k)) for k in edges.keys()])
33+
return new_vertices_idx, new_faces
34+
35+
def forward(self, vertices, faces):
36+
"""
37+
*
38+
/ \
39+
/ \
40+
/ \
41+
* ----- *
42+
|
43+
*
44+
/ \
45+
o - o
46+
/ \ / \
47+
* --o-- *
48+
"""
49+
device = vertices.device
50+
new_vertices_idx_list, new_faces_list = [], []
51+
for i, fs in enumerate(faces):
52+
new_vertices_idx, new_faces = self.calculate_faces(fs.detach().cpu().numpy(), len(vertices[i]))
53+
new_vertices_idx_list.append(np.expand_dims(new_vertices_idx, axis=0))
54+
new_faces_list.append(np.expand_dims(new_faces, axis=0))
55+
new_vertices_idx_list = torch.from_numpy(np.vstack(new_vertices_idx_list)).long().to(device)
56+
new_faces_list = torch.from_numpy(np.vstack(new_faces_list)).long().to(device)
57+
58+
expand_vertices = vertices.unsqueeze(1).expand(-1, new_vertices_idx_list.shape[1], -1, -1)
59+
expand_vertices_idx = new_vertices_idx_list.unsqueeze(-1).expand(-1, -1, -1, 3)
60+
new_verts = torch.mean(torch.gather(expand_vertices, 2, expand_vertices_idx), dim=-2)
61+
new_verts = torch.cat([vertices, new_verts], dim=1)
62+
return new_verts, new_faces_list

0 commit comments

Comments
 (0)