identify.py
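This script selects language-specific neurons from pre-computed activation statistics. For each language it loads `data/activation.{lang}.train.llama-70b`, which holds a token count `n` and an `over_zero` tensor of per-neuron firing counts. It converts the counts to activation probabilities, keeps the small fraction of neurons whose activation distribution over languages has the lowest entropy (i.e., neurons that fire almost exclusively for a few languages), assigns each kept neuron to the languages that activate it strongly, and saves the per-language (layer, neuron) indices to `activation_mask/llama-70b`.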
import os

import torch

# Aggregate activation statistics per language: 'n' is the token count and
# 'over_zero' counts how often each FFN neuron fired (activation > 0).
n, over_zero = [], []
for lang in ['en', 'zh', 'fr', 'es', 'vi', 'id', 'ja']:
    data = torch.load(f'data/activation.{lang}.train.llama-70b')
    n.append(data['n'])
    over_zero.append(data['over_zero'])

n = torch.tensor(n)
over_zero = torch.stack(over_zero, dim=-1)

num_layers, intermediate_size, lang_num = over_zero.size()


def activation():
    top_rate = 0.01              # keep the 1% of neurons with the lowest entropy
    filter_rate = 0.95           # percentile a neuron must reach to stay in the pool
    activation_bar_ratio = 0.95  # percentile a language must clear to "own" a neuron

    # Per-neuron activation probability for each language: layer x inter x lang_num.
    activation_probs = over_zero / n
    # Normalize across languages so each neuron's probabilities form a distribution.
    normed_activation_probs = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    normed_activation_probs[torch.isnan(normed_activation_probs)] = 0
    log_probs = torch.where(normed_activation_probs > 0, normed_activation_probs.log(), 0)
    # Low entropy means the neuron fires for only a few languages.
    entropy = -torch.sum(normed_activation_probs * log_probs, dim=-1)
    largest = False  # select the smallest-entropy (most language-specific) neurons

    if torch.isnan(entropy).sum():
        print(torch.isnan(entropy).sum())
        raise ValueError('NaN entries found in the entropy tensor')

    flattened_probs = activation_probs.flatten()
    top_prob_value = flattened_probs.kthvalue(round(len(flattened_probs) * filter_rate)).values.item()
    print(top_prob_value)
    # Dismiss a neuron if no language pushes its activation probability above
    # the 95th percentile (filter_rate) of all activation probabilities.
    top_position = (activation_probs > top_prob_value).sum(dim=-1)
    entropy[top_position == 0] = -torch.inf if largest else torch.inf

    # Keep the top_rate fraction of neurons with the lowest entropy.
    flattened_entropy = entropy.flatten()
    top_entropy_value = round(len(flattened_entropy) * top_rate)
    _, index = flattened_entropy.topk(top_entropy_value, largest=largest)
    row_index = index // entropy.size(1)  # layer indices
    col_index = index % entropy.size(1)   # neuron indices within a layer
    selected_probs = activation_probs[row_index, col_index]  # selected x lang_num
    # for r, c in zip(row_index, col_index):
    #     print(r, c, activation_probs[r][c])
    print(selected_probs.size(0), torch.bincount(selected_probs.argmax(dim=-1)))

    # Assign each selected neuron to every language whose activation probability
    # clears the activation bar (the activation_bar_ratio percentile).
    selected_probs = selected_probs.transpose(0, 1)  # lang_num x selected
    activation_bar = flattened_probs.kthvalue(round(len(flattened_probs) * activation_bar_ratio)).values.item()
    print((selected_probs > activation_bar).sum(dim=1).tolist())
    lang, indice = torch.where(selected_probs > activation_bar)
    merged_index = torch.stack((row_index, col_index), dim=-1)

    # Group the (layer, neuron) pairs per language, then bucket them per layer.
    final_indice = []
    for index in indice.split(torch.bincount(lang).tolist()):
        lang_index = [tuple(row.tolist()) for row in merged_index[index]]
        lang_index.sort()
        layer_index = [[] for _ in range(num_layers)]
        for l, h in lang_index:
            layer_index[l].append(h)
        for l, h in enumerate(layer_index):
            layer_index[l] = torch.tensor(h).long()
        final_indice.append(layer_index)

    os.makedirs('activation_mask', exist_ok=True)  # ensure the output dir exists
    torch.save(final_indice, 'activation_mask/llama-70b')


activation()
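The expected input layout can be inferred from the loading loop: each `data/activation.{lang}.train.llama-70b` file is a dict with an integer token count `n` and an integer `over_zero` tensor of shape (num_layers, intermediate_size). A minimal sketch that fabricates such files for a smoke test; the dimensions below are placeholders, not the real model's:

import os
import torch

os.makedirs('data', exist_ok=True)
num_layers, intermediate_size = 4, 128  # placeholder dims for a quick test
for lang in ['en', 'zh', 'fr', 'es', 'vi', 'id', 'ja']:
    n_tokens = 10_000
    # Fake per-neuron firing counts in [0, n_tokens].
    over_zero = torch.randint(0, n_tokens + 1, (num_layers, intermediate_size))
    torch.save({'n': n_tokens, 'over_zero': over_zero},
               f'data/activation.{lang}.train.llama-70b')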