-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualization.py
72 lines (63 loc) · 3.1 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import cv2
# Colour triplets used as the default `positive_channel` / `negative_channel`
# arguments of `visualize`. The names and values suggest RGB channel order
# (G = pure green, R = pure red) -- TODO confirm against callers.
G = [0, 255, 0]
R = [255, 0, 0]
def convert_to_gray_scale(attributions):
    """Collapse axis 2 of ``attributions`` by taking its unweighted mean.

    Args:
        attributions: Array-like with at least three dimensions. Axis 2 is
            averaged away (presumably the colour-channel axis -- confirm
            against callers).

    Returns:
        Array with axis 2 removed, holding the mean along that axis.
    """
    gray = np.mean(attributions, axis=2)
    return gray
def linear_transform(attributions, clip_above_percentile=99.9, clip_below_percentile=70.0, low=0.2, plot_distribution=False):
    """Linearly rescale attribution magnitudes between two thresholds.

    Two thresholds are obtained from ``compute_threshold_by_top_percentage``:
    an upper one for ``100 - clip_above_percentile`` and a lower one for
    ``100 - clip_below_percentile``. Magnitudes are mapped so that the lower
    threshold lands on ``low`` and the upper threshold on 1.0. The sign of the
    input is then restored, entries that end up below ``low`` are zeroed, and
    the result is clipped into ``[0, 1]``.

    NOTE(review): when both thresholds are equal the division below is a
    division by zero; this is unchanged from the existing behaviour.

    Args:
        attributions: Array of attribution values.
        clip_above_percentile: Percentile used to derive the upper threshold.
        clip_below_percentile: Percentile used to derive the lower threshold.
        low: Output value assigned to the lower threshold.
        plot_distribution: Forwarded to ``compute_threshold_by_top_percentage``.

    Returns:
        Array of the same shape as ``attributions`` with values in ``[0, 1]``.
    """
    def _threshold(clip_percentile):
        # Thresholds are expressed as "top X percent", hence the complement.
        return compute_threshold_by_top_percentage(
            attributions,
            percentage=100 - clip_percentile,
            plot_distribution=plot_distribution,
        )

    upper = _threshold(clip_above_percentile)
    lower = _threshold(clip_below_percentile)

    rescaled = (1 - low) * (np.abs(attributions) - lower) / (upper - lower) + low
    rescaled *= np.sign(attributions)
    # Boolean mask: anything that fell below `low` is dropped to zero.
    rescaled *= rescaled >= low
    return np.clip(rescaled, 0.0, 1.0)
def compute_threshold_by_top_percentage(attributions, percentage=60, plot_distribution=True):
    """Return the magnitude at which the largest entries reach ``percentage``.

    Absolute values are sorted in descending order and accumulated; the
    threshold is the first magnitude at which the running total reaches
    ``percentage`` percent of the total absolute attribution.

    The running total is normalised by the sum of magnitudes (not the signed
    sum), so mixed-sign input yields percentages in ``[0, 100]`` instead of
    overshooting, going negative, or dividing by a cancelled-out zero.

    Args:
        attributions: Array of attribution values.
        percentage: Share of total magnitude to cover, in ``[0, 100]``.
        plot_distribution: Plotting is not implemented; a truthy value raises
            ``NotImplementedError`` once the threshold has been computed
            (the ``percentage == 100`` shortcut returns before that check).

    Returns:
        The threshold magnitude. For ``percentage == 100`` this is
        ``np.min(attributions)``. For input whose magnitudes are all zero the
        threshold is zero.

    Raises:
        ValueError: If ``percentage`` is outside ``[0, 100]``.
        NotImplementedError: If ``plot_distribution`` is truthy.
    """
    if percentage < 0 or percentage > 100:
        raise ValueError('percentage must be in [0, 100]')
    if percentage == 100:
        return np.min(attributions)

    magnitudes = np.abs(np.asarray(attributions)).ravel()
    # Normalise by total magnitude so the cumulative curve ends at 100%.
    total_magnitude = np.sum(magnitudes)
    sorted_attributions = np.sort(magnitudes)[::-1]

    if total_magnitude == 0:
        # Every entry is zero: avoid 0/0 and report the (zero) top magnitude.
        threshold = sorted_attributions[0]
    else:
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total_magnitude
        threshold_idx = np.where(cum_sum >= percentage)[0][0]
        threshold = sorted_attributions[threshold_idx]

    if plot_distribution:
        raise NotImplementedError
    return threshold
def polarity_function(attributions, polarity):
    """Keep only one sign of the attributions by clipping.

    Args:
        attributions: Array of attribution values.
        polarity: ``'positive'`` clips values into ``[0, 1]``;
            ``'negative'`` clips values into ``[-1, 0]``.

    Returns:
        The clipped array.

    Raises:
        NotImplementedError: For any other ``polarity`` value.
    """
    if polarity not in ('positive', 'negative'):
        raise NotImplementedError

    lower, upper = (0, 1) if polarity == 'positive' else (-1, 0)
    return np.clip(attributions, lower, upper)
def overlay_function(attributions, image, image_weight=0.7, attribution_weight=0.5):
    """Blend an attribution map onto an image and clip to ``[0, 255]``.

    Computes ``image_weight * image + attribution_weight * attributions``.
    The defaults (0.7 and 0.5) reproduce the previously hard-coded blend, so
    existing two-argument calls behave exactly as before.

    Args:
        attributions: Attribution array, broadcast-compatible with ``image``.
        image: Image array.
        image_weight: Multiplier applied to ``image``.
        attribution_weight: Multiplier applied to ``attributions``.

    Returns:
        The weighted sum, clipped into ``[0, 255]``.
    """
    blended = image_weight * image + attribution_weight * attributions
    return np.clip(blended, 0, 255)
def visualize(attributions, image, positive_channel=G, negative_channel=R, polarity='positive',
              clip_above_percentile=99.9, clip_below_percentile=0, morphological_cleanup=False,
              structure=np.ones((3, 3)), outlines=False, outlines_component_percentage=90, overlay=True,
              mask_mode=False, plot_distribution=False):
    """Turn attributions into a coloured map, optionally combined with ``image``.

    Pipeline: select one polarity, average over axis 2, rescale with
    ``linear_transform`` (``low=0.0``), colour with the chosen channel triplet,
    then either blend with ``image``, use the rescaled map as a multiplicative
    mask on ``image``, or return the coloured map alone. The last axis is
    finally reordered with indices ``(2, 1, 0)`` (presumably RGB -> BGR for
    OpenCV consumers -- confirm against callers).

    Args:
        attributions: Attribution array with a third (channel) axis.
        image: Image array the attributions refer to.
        positive_channel: Colour triplet used when ``polarity='positive'``.
        negative_channel: Colour triplet used when ``polarity='negative'``.
        polarity: ``'positive'`` or ``'negative'``. ``'both'`` is not
            implemented.
        clip_above_percentile: Forwarded to ``linear_transform``.
        clip_below_percentile: Forwarded to ``linear_transform``.
        morphological_cleanup: Not implemented; a truthy value raises.
        structure: Unused (reserved for morphological cleanup).
        outlines: Not implemented; a truthy value raises.
        outlines_component_percentage: Unused (reserved for outlines).
        overlay: When truthy, combine the result with ``image``.
        mask_mode: With ``overlay``, multiply ``image`` by the rescaled map
            instead of blending colours.
        plot_distribution: Forwarded to ``linear_transform``.

    Returns:
        Array with three channels on the last axis, in reversed channel order.

    Raises:
        NotImplementedError: For ``polarity='both'``, ``morphological_cleanup``
            or ``outlines``.
        ValueError: For an unrecognised ``polarity``.
    """
    if polarity == 'both':
        raise NotImplementedError
    elif polarity == 'positive':
        attributions = polarity_function(attributions, polarity=polarity)
        channel = positive_channel
    elif polarity == 'negative':
        attributions = polarity_function(attributions, polarity=polarity)
        # linear_transform zeroes non-positive results, so work on magnitudes.
        attributions = np.abs(attributions)
        channel = negative_channel
    else:
        # Previously this fell through and crashed later on an unbound `channel`.
        raise ValueError(
            "polarity must be 'positive', 'negative' or 'both', got %r" % (polarity,)
        )

    # convert the attributions to the gray scale
    attributions = convert_to_gray_scale(attributions)
    attributions = linear_transform(attributions, clip_above_percentile, clip_below_percentile, 0.0,
                                    plot_distribution=plot_distribution)
    # Keep the single-channel map: mask mode needs it after colouring.
    attributions_mask = attributions.copy()
    if morphological_cleanup:
        raise NotImplementedError
    if outlines:
        raise NotImplementedError

    # Broadcast (H, W, 1) against the 3-element colour triplet.
    attributions = np.expand_dims(attributions, 2) * channel
    if overlay:
        if not mask_mode:
            attributions = overlay_function(attributions, image)
        else:
            attributions = np.expand_dims(attributions_mask, 2)
            attributions = np.clip(attributions * image, 0, 255)
    # Reverse the channel order on the last axis.
    attributions = attributions[:, :, (2, 1, 0)]
    return attributions