-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
110 lines (90 loc) · 4.05 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os

from matplotlib import pyplot as plt
import numpy as np
import scipy.cluster
import scipy.cluster.hierarchy
import scipy.spatial.distance
import seaborn as sns
import sklearn
import sklearn.cluster
import sklearn.manifold
# (annotation, x, y, z) tuples for the labels of the most recent 3D MDS plot.
# Kept at module level as a workaround for namespace problems in ``mds``.
labels_and_points = list()
# Seaborn's 'deep' palette as a NumPy array (one RGB row per color), so it can
# be indexed directly with an integer array of cluster assignments.
PALETTE = np.asarray(sns.color_palette('deep'))
def _cluster_colors(df, clustering, clusters):
    """Return one PALETTE color (RGB row) per item of the dissimilarity matrix ``df``.

    When ``clustering`` is truthy, items are grouped by complete-linkage
    agglomerative clustering on the precomputed dissimilarities in ``df`` and
    each cluster gets its own color; otherwise every item gets the same color.
    """
    if clustering:
        try:
            # scikit-learn >= 1.2 calls this parameter ``metric``.
            model = sklearn.cluster.AgglomerativeClustering(
                linkage='complete', metric='precomputed', n_clusters=clusters)
        except TypeError:
            # Older scikit-learn only accepts ``affinity`` (removed in 1.4).
            model = sklearn.cluster.AgglomerativeClustering(
                linkage='complete', affinity='precomputed', n_clusters=clusters)
        assignments = np.asarray(model.fit_predict(df))
    else:
        assignments = np.zeros(len(df.index), dtype=int)
    # Skip the first palette color, and wrap around so that asking for more
    # clusters than the palette has colors does not raise IndexError.
    return PALETTE[(assignments + 1) % len(PALETTE)]


def _fit_embedding(df, dim, metric):
    """Embed the precomputed dissimilarities in ``df`` into ``dim`` dimensions.

    Returns ``(points, stress)``: the MDS embedding (one row per item) and the
    final stress of the fit.
    """
    model = sklearn.manifold.MDS(n_components=dim, metric=metric, eps=1e-9,
                                 dissimilarity="precomputed")
    points = model.fit(df).embedding_
    return points, model.stress_


def _plot_2d(items, points, colors):
    """Draw a labelled 2D scatter of ``points`` on the current axes."""
    plt.scatter(points[:, 0], points[:, 1], c=colors, s=40)
    for label, x, y in zip(items, points[:, 0], points[:, 1]):
        plt.annotate(label, xy=(x, y), xytext=(-5, 5),
                     textcoords='offset points', ha='right', va='bottom')
    plt.xticks([])
    plt.yticks([])
    plt.gca().set_aspect('equal', 'datalim')


def _plot_3d(items, points, colors):
    """Draw a labelled 3D scatter of ``points`` on the current figure.

    Labels are 2D annotations placed at the projected position of each point.
    A motion handler re-projects them so they follow their points while the
    user rotates the view.
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection
    from mpl_toolkits.mplot3d import proj3d
    global labels_and_points  # module-level copy kept for existing readers

    # Reuse the (already cleared) current figure instead of leaking it.
    fig = plt.gcf()
    ax = fig.add_subplot(111, projection='3d')
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    ax.scatter(xs, ys, zs, c=colors, s=40)

    placed = []
    proj = ax.get_proj()
    for feature, x, y, z in zip(items, xs, ys, zs):
        x2, y2, _ = proj3d.proj_transform(x, y, z, proj)
        label = ax.annotate(feature, xy=(x2, y2), xytext=(-5, 5),
                            textcoords='offset points', ha='right', va='bottom')
        placed.append((label, x, y, z))
    labels_and_points = placed

    def update_position(event):
        # Each figure's handler uses its own list, so plotting again does not
        # make an older figure move the newer figure's labels.
        current = ax.get_proj()
        for label, x, y, z in placed:
            label.xy = proj3d.proj_transform(x, y, z, current)[:2]
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', update_position)


def mds(df, name="", dim=2, metric=True, clustering=True, clusters=4, interactive=False,
        cmap='Set1'):
    """Save a scatterplot of the items of ``df`` projected onto 2 or 3 dimensions.

    Uses multidimensional scaling (MDS) on the precomputed dissimilarities in
    ``df`` to place every item in 2 or 3 dimensions. Points are optionally
    colored by complete-linkage clustering. The figure is saved to
    ``figs/{name}_mds{dim}.png``.

    Args:
        df: square DataFrame of pairwise dissimilarities; its index supplies
            the point labels.
        name: prefix for the saved file name.
        dim: number of embedding dimensions, 2 or 3.
        metric: passed to ``sklearn.manifold.MDS`` (metric vs. non-metric MDS).
        clustering: if truthy, color points by agglomerative clustering;
            otherwise all points share one color.
        clusters: number of clusters used when ``clustering`` is truthy.
        interactive: if truthy, call ``plt.show()`` after saving. This is
            recommended for 3D plots, which are hard to read without rotation.
        cmap: unused; accepted for backward compatibility. Colors come from
            the module-level ``PALETTE``.

    Returns:
        The final stress of the MDS fit.

    Raises:
        ValueError: if ``dim`` is not 2 or 3.
    """
    if dim not in (2, 3):
        raise ValueError('dim must be 2 or 3. {} provided'.format(dim))
    items = df.index
    plt.clf()
    colors = _cluster_colors(df, clustering, clusters)
    points, stress = _fit_embedding(df, dim, metric)
    if dim == 2:
        _plot_2d(items, points, colors)
    else:
        _plot_3d(items, points, colors)
    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/{}_mds{}.png'.format(name, dim))
    if interactive:
        plt.show()
    return stress
def _looks_like_distance_matrix(data):
    """Return True if ``data`` is a square, symmetric matrix with a zero diagonal."""
    return (data.ndim == 2 and data.shape[0] == data.shape[1]
            and np.allclose(data, data.T) and np.allclose(np.diag(data), 0))


def dendrogram(df, name="", method='complete'):
    """Plot a dendrogram using hierarchical clustering and return its inconsistency.

    A square, symmetric, zero-diagonal ``df`` is treated as a pairwise distance
    matrix (as ``mds`` does) and condensed before clustering. Any other 2D
    ``df`` is passed to ``linkage`` as observation vectors, one row per item.
    The figure is saved to ``figs/{name}_dendrogram.png``.

    Args:
        df: DataFrame whose index supplies the leaf labels.
        name: prefix for the saved file name.
        method: linkage method; see ``scipy.cluster.hierarchy.linkage``.
            Note that 'centroid', 'median' and 'ward' assume Euclidean
            distances.

    Returns:
        The inconsistency matrix from ``scipy.cluster.hierarchy.inconsistent``.
    """
    items = df.index
    plt.clf()
    data = np.asarray(df, dtype=float)
    if _looks_like_distance_matrix(data):
        # linkage() reads any 2D array as observation vectors. Passing the
        # square distance matrix unchanged would re-measure Euclidean
        # distances between its rows, so pass the condensed form instead.
        data = scipy.spatial.distance.squareform(data, checks=False)
    clustering = scipy.cluster.hierarchy.linkage(data, method=method)
    scipy.cluster.hierarchy.dendrogram(clustering, orientation='left', truncate_mode=None,
                                       labels=list(items), color_threshold=0)
    plt.tight_layout()
    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/{}_dendrogram.png'.format(name))
    return scipy.cluster.hierarchy.inconsistent(clustering)
if __name__ == '__main__':
    # Demo on a tiny distance matrix: 'a' and 'b' are close, 'c' is far from both.
    import pandas as pd

    labels = list('abc')
    distances = [[0, 1, 9],
                 [1, 0, 9],
                 [9, 9, 0]]
    df = pd.DataFrame(distances, index=labels, columns=labels)
    dendrogram(df, "foobar")
    mds(df, "foobar", interactive=True, clusters=2)