# Prosody_Visualization.py
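# Visualizes GST (Global Style Token) prosody embeddings with t-SNE.
# Every reference mel-spectrogram (.npy) found under the configured dataset
# directories is passed through the GST encoder of a trained checkpoint; the
# resulting style embeddings are projected to 2-D with t-SNE and saved as a
# scatter plot, coloured per dataset label.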
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from tqdm import tqdm
from sklearn.manifold import TSNE
from train import load_model
from text import kor_text_to_phoneme, kr_phoneme_symbols
from hparams import create_hparams
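
# Note: `train`, `text`, and `hparams` are assumed to be repo-local modules
# (a Tacotron-2/GST-style layout); `load_model` builds the full synthesis
# model, of which only the `.gst` submodule is used below.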


class Dataset(torch.utils.data.Dataset):
    """Collects every .npy reference mel under each dataset root, recursively."""
    def __init__(self, ref_mel_paths):
        super(Dataset, self).__init__()
        # ref_mel_paths: list of (directory, dataset label) pairs.
        self.path_list = [
            (os.path.join(root, file), dataset)
            for paths, dataset in ref_mel_paths
            for root, _, files in os.walk(paths)
            for file in files
            if os.path.splitext(file)[1].lower() == '.npy'
            ]

    def __getitem__(self, index):
        path, dataset = self.path_list[index]
        # Returns (mel of shape [T, n_mels], dataset label).
        return np.load(path), dataset

    def __len__(self):
        return len(self.path_list)


class Collate:
    """Zero-pads a batch of mels to a common length and returns model-ready tensors."""
    def __call__(self, batch):
        mels, datasets = zip(*batch)
        lengths = [mel.shape[0] for mel in mels]
        max_length = max(lengths)
        # Pad each [T, n_mels] mel along the time axis to max_length.
        mels = np.stack([
            np.pad(mel, [[0, max_length - mel.shape[0]], [0, 0]])
            for mel in mels
            ])
        # [B, T, n_mels] -> [B, n_mels, T] for the GST module.
        mels = torch.FloatTensor(mels).transpose(2, 1)
        lengths = torch.LongTensor(lengths)
        return mels, lengths, datasets


def getColor(c, N, idx):
    """Returns the idx-th of N evenly spaced colours from colormap c."""
    import matplotlib as mpl
    cmap = mpl.cm.get_cmap(c)
    norm = mpl.colors.Normalize(vmin=0.0, vmax=N - 1)
    return cmap(norm(idx))


# Map phoneme index -> symbol (not used further in this visualization script).
kr_phoneme_symbols = {key: value for key, value in enumerate(kr_phoneme_symbols.phoneme_symbols)}

hparams = create_hparams()

checkpoint_path = '/home/heejo/Documents/TMAX/Temp/tsd1_emo_gst/checkpoint_171000'
# e.g. checkpoint_171000 -> '<checkpoint dir>/R171'
export_path = os.path.join(
    os.path.dirname(checkpoint_path),
    '{}{}'.format('R', os.path.basename(checkpoint_path).split('_')[1][:-3])
    )
os.makedirs(export_path, exist_ok=True)  # make sure the plot directory exists

# (reference mel directory, dataset label) pairs; each directory is scanned recursively.
ref_mel_paths = [
('/home/heejo/data/tsd1_prep1', 'TSD1'),
('/home/heejo/data/emo_data/ada_mel', 'EMO_ADA'),
('/home/heejo/data/emo_data/adb_mel', 'EMO_ADB'),
('/home/heejo/data/emo_data/adc_mel', 'EMO_ADC'),
('/home/heejo/data/emo_data/add_mel', 'EMO_ADD'),
('/home/heejo/data/emo_data/ava_mel', 'EMO_AVA'),
('/home/heejo/data/emo_data/avb_mel', 'EMO_AVB'),
('/home/heejo/data/emo_data/avc_mel', 'EMO_AVC'),
('/home/heejo/data/emo_data/avd_mel', 'EMO_AVD'),
('/home/heejo/data/emo_data/ema_mel', 'EMO_EMA'),
('/home/heejo/data/emo_data/emb_mel', 'EMO_EMB'),
('/home/heejo/data/emo_data/emf_mel', 'EMO_EMF'),
('/home/heejo/data/emo_data/emg_mel', 'EMO_EMG'),
('/home/heejo/data/emo_data/emh_mel', 'EMO_EMH'),
('/home/heejo/data/emo_data/lmy_mel', 'EMO_LMY'),
('/home/heejo/data/emo_data/nea_mel', 'EMO_NEA'),
('/home/heejo/data/emo_data/neb_mel', 'EMO_NEB'),
('/home/heejo/data/emo_data/nec_mel', 'EMO_NEC'),
('/home/heejo/data/emo_data/ned_mel', 'EMO_NED'),
('/home/heejo/data/emo_data/nee_mel', 'EMO_NEE'),
('/home/heejo/data/emo_data/nek_mel', 'EMO_NEK'),
('/home/heejo/data/emo_data/nel_mel', 'EMO_NEL'),
('/home/heejo/data/emo_data/nem_mel', 'EMO_NEM'),
('/home/heejo/data/emo_data/nen_mel', 'EMO_NEN'),
    ('/home/heejo/data/emo_data/neo_mel', 'EMO_NEO'),
]
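
# Build the full model from hparams, restore the checkpoint weights, then keep
# only the GST (reference encoder + style token) submodule on the GPU in half
# precision, matching the .half() inputs used below.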
model = load_model(hparams)
model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
model = model.gst.cuda()
model.eval().half()

loader = torch.utils.data.DataLoader(
    dataset=Dataset(ref_mel_paths),
    num_workers=2,
    shuffle=False,
    batch_size=hparams.batch_size,
    pin_memory=False,
    collate_fn=Collate()
    )
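
# Run every batch through the GST module (no gradients needed) and collect one
# style embedding per utterance, together with its dataset label.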
with torch.no_grad():
    embeddings, datasets = zip(*[
        (model(mels.cuda().half(), lengths.cuda()).cpu().numpy().astype(np.float32), datasets)
        for mels, lengths, datasets in tqdm(loader)
    ])
# embeddings = np.squeeze(np.vstack(embeddings), axis= 1)
embeddings = np.vstack(embeddings)
datasets = [dataset for sub in datasets for dataset in sub]

# Project all style embeddings to 2-D for visualization.
scatters = TSNE(n_components=2, random_state=0).fit_transform(embeddings)

# One scatter colour per dataset label.
fig = plt.figure(figsize=(12, 12))
for index, (_, dataset) in enumerate(ref_mel_paths):
    sub_scatters = [
        scatter
        for scatter, scatter_dataset in zip(scatters, datasets)
        if dataset == scatter_dataset
        ]
    if len(sub_scatters) == 0:
        continue
    sub_scatters = np.stack(sub_scatters)
    plt.scatter(
        sub_scatters[:, 0],
        sub_scatters[:, 1],
        c=np.array([getColor('gist_ncar', len(ref_mel_paths), index)] * len(sub_scatters)),
        label=dataset
        )
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(export_path, 'gst_tsne.png'))
plt.close()