# mushroom_data.py
# Forked from CW-Huang/BayesianHypernet.
# modified from https://github.com/hlin117/data-science/blob/master/classification/mushroom.py
import pandas as pd
import numpy as np
"""The mushroom data was obtained through the UCI website. The task
is to classify whether a mushroom is edible or not."""
# Column names applied to the raw data file, in order: the class label
# ("edible") first, followed by the 22 descriptive attributes.
columns = [
    # class label
    "edible",
    # cap
    "cap-shape",
    "cap-surface",
    "cap-color",
    "bruises?",
    "odor",
    # gill
    "gill-attachment",
    "gill-spacing",
    "gill-size",
    "gill-color",
    # stalk
    "stalk-shape",
    "stalk-root",
    "stalk-surface-above-ring",
    "stalk-surface-below-ring",
    "stalk-color-above-ring",
    "stalk-color-below-ring",
    # veil and ring
    "veil-type",
    "veil-color",
    "ring-number",
    "ring-type",
    # remaining attributes
    "spore-print-color",
    "population",
    "habitat",
]
# Read the raw table. Because `names=` is given without `header=`, pandas
# treats the first line of the file as data and labels the fields with
# `columns`.
dataset = pd.read_csv("data/mushroom.data", names=columns, index_col=None)

# One-hot encode the named columns into a single 2-D array.
# The dtype is pinned explicitly: pandas >= 2.0 returns boolean dummies while
# earlier releases returned uint8, so without it the dtype of YX/Y/X would
# depend on the installed pandas version. uint8 keeps the 0/1 numeric encoding.
YX = np.array(pd.get_dummies(dataset[columns]), dtype=np.uint8)

# Target: the first "edible" dummy, kept 2-D (n_rows, 1) by the 0:1 slice.
# NOTE(review): presumably this is the "e" (edible) category, which sorts
# before "p" -- confirm against the category codes in the data file.
Y = YX[:, 0:1]  # is_edible

# Features: everything after the two "edible" dummies (columns 0 and 1), so
# the label is not included among the inputs.
X = YX[:, 2:]  # edible gets 2 dummies