|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | +from torch.nn import Module |
| 4 | + |
| 5 | + |
| 6 | +class UpSampleLayer(Module): |
| 7 | + def __init__(self): |
| 8 | + super().__init__() |
| 9 | + |
| 10 | + @staticmethod |
| 11 | + def calculate_faces(faces, vn): |
| 12 | + edges = {} |
| 13 | + new_faces = [] |
| 14 | + |
| 15 | + def get_edge_id(e): |
| 16 | + if e not in edges: |
| 17 | + edges[e] = len(edges) |
| 18 | + return edges[e] |
| 19 | + |
| 20 | + for f in faces: |
| 21 | + a, b, c = f[0], f[1], f[2] |
| 22 | + e1, e2, e3 = tuple(sorted([a, b])), tuple(sorted([b, c])), tuple(sorted([c, a])) |
| 23 | + x = get_edge_id(e1) + vn |
| 24 | + y = get_edge_id(e2) + vn |
| 25 | + z = get_edge_id(e3) + vn |
| 26 | + new_faces.append(np.array([x, y, z])) |
| 27 | + new_faces.append(np.array([a, x, z])) |
| 28 | + new_faces.append(np.array([b, y, x])) |
| 29 | + new_faces.append(np.array([c, z, y])) |
| 30 | + |
| 31 | + new_faces = np.vstack(new_faces) |
| 32 | + new_vertices_idx = np.vstack([np.array(list(k)) for k in edges.keys()]) |
| 33 | + return new_vertices_idx, new_faces |
| 34 | + |
| 35 | + def forward(self, vertices, faces): |
| 36 | + """ |
| 37 | + * |
| 38 | + / \ |
| 39 | + / \ |
| 40 | + / \ |
| 41 | + * ----- * |
| 42 | + | |
| 43 | + * |
| 44 | + / \ |
| 45 | + o - o |
| 46 | + / \ / \ |
| 47 | + * --o-- * |
| 48 | + """ |
| 49 | + device = vertices.device |
| 50 | + new_vertices_idx_list, new_faces_list = [], [] |
| 51 | + for i, fs in enumerate(faces): |
| 52 | + new_vertices_idx, new_faces = self.calculate_faces(fs.detach().cpu().numpy(), len(vertices[i])) |
| 53 | + new_vertices_idx_list.append(np.expand_dims(new_vertices_idx, axis=0)) |
| 54 | + new_faces_list.append(np.expand_dims(new_faces, axis=0)) |
| 55 | + new_vertices_idx_list = torch.from_numpy(np.vstack(new_vertices_idx_list)).long().to(device) |
| 56 | + new_faces_list = torch.from_numpy(np.vstack(new_faces_list)).long().to(device) |
| 57 | + |
| 58 | + expand_vertices = vertices.unsqueeze(1).expand(-1, new_vertices_idx_list.shape[1], -1, -1) |
| 59 | + expand_vertices_idx = new_vertices_idx_list.unsqueeze(-1).expand(-1, -1, -1, 3) |
| 60 | + new_verts = torch.mean(torch.gather(expand_vertices, 2, expand_vertices_idx), dim=-2) |
| 61 | + new_verts = torch.cat([vertices, new_verts], dim=1) |
| 62 | + return new_verts, new_faces_list |
0 commit comments