Skip to content

Commit 3ea9646

Browse files
authored
Merge pull request #18 from alinzh/master
Change Pandas on Polars package
2 parents 75ccb3f + 7ef9fb0 commit 3ea9646

File tree

7 files changed

+389
-252
lines changed

7 files changed

+389
-252
lines changed

mpds_client/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import sys
32

43
from .retrieve_MPDS import MPDSDataTypes, APIError, MPDSDataRetrieval
@@ -7,4 +6,4 @@
76

87
MIN_PY_VER = (3, 5)
98

10-
assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)
9+
assert sys.version_info >= MIN_PY_VER, "Python version must be >= {}".format(MIN_PY_VER)

mpds_client/errors.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
21
class APIError(Exception):
32
"""
43
Simple error handling
54
"""
5+
66
codes = {
7-
204: 'No Results',
8-
400: 'Bad Request',
9-
401: 'Unauthorized',
10-
402: 'Unauthorized (Payment Required)',
11-
403: 'Forbidden',
12-
404: 'Not Found',
13-
413: 'Too Much Data Given',
14-
429: 'Too Many Requests (Rate Limiting)',
15-
500: 'Internal Server Error',
16-
501: 'Not Implemented',
17-
503: 'Service Unavailable'
7+
204: "No Results",
8+
400: "Bad Request",
9+
401: "Unauthorized",
10+
402: "Unauthorized (Payment Required)",
11+
403: "Forbidden",
12+
404: "Not Found",
13+
413: "Too Much Data Given",
14+
429: "Too Many Requests (Rate Limiting)",
15+
500: "Internal Server Error",
16+
501: "Not Implemented",
17+
503: "Service Unavailable",
1818
}
1919

2020
def __init__(self, msg, code=0):
@@ -23,4 +23,8 @@ def __init__(self, msg, code=0):
2323
self.code = code
2424

2525
def __str__(self):
26-
return "HTTP error code %s: %s (%s)" % (self.code, self.codes.get(self.code, 'Communication Error'), self.msg)
26+
return "HTTP error code %s: %s (%s)" % (
27+
self.code,
28+
self.codes.get(self.code, "Communication Error"),
29+
self.msg,
30+
)

mpds_client/export_MPDS.py

Lines changed: 102 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
Utilities for convenient
33
exporting the MPDS data
44
"""
5+
56
import os
67
import random
7-
88
import ujson as json
9-
import pandas as pd
9+
import polars as pl
10+
from typing import Union
1011

1112

1213
class MPDSExport(object):
13-
1414
export_dir = "/tmp/_MPDS"
1515

1616
human_names = {
17-
'length': 'Bond lengths, A',
18-
'occurrence': 'Counts',
19-
'bandgap': 'Band gap, eV'
17+
"length": "Bond lengths, A",
18+
"occurrence": "Counts",
19+
"bandgap": "Band gap, eV",
2020
}
2121

2222
@classmethod
@@ -32,15 +32,21 @@ def _gen_basename(cls):
3232
basename = []
3333
random.seed()
3434
for _ in range(12):
35-
basename.append(random.choice("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"))
35+
basename.append(
36+
random.choice(
37+
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
38+
)
39+
)
3640
return "".join(basename)
3741

3842
@classmethod
39-
def _get_title(cls, term):
43+
def _get_title(cls, term: Union[str, int]):
44+
if isinstance(term, int):
45+
return str(term)
4046
return cls.human_names.get(term, term.capitalize())
4147

4248
@classmethod
43-
def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
49+
def save_plot(cls, data, columns, plottype, fmt="json", **kwargs):
4450
"""
4551
Exports the data in the following formats for plotting:
4652
@@ -50,87 +56,114 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
5056
cls._verify_export_dir()
5157
plot = {"use_visavis_type": plottype, "payload": {}}
5258

53-
if isinstance(data, pd.DataFrame):
54-
iter_data = data.iterrows
55-
pointers = columns
56-
else:
57-
iter_data = lambda: enumerate(data)
58-
pointers = range(len(data[0]))
59+
if not isinstance(data, pl.DataFrame):
60+
raise TypeError("The 'data' parameter must be a Polars DataFrame")
61+
62+
# check that columns are valid
63+
if not all(col in data.columns for col in columns):
64+
raise ValueError("Some specified columns are not in the DataFrame")
5965

60-
if fmt == 'csv':
66+
if fmt == "csv":
67+
# export to CSV
6168
fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".csv")
62-
f_export = open(fmt_export, "w")
63-
f_export.write("%s\n" % ",".join(map(str, columns)))
64-
for _, row in iter_data():
65-
f_export.write("%s\n" % ",".join([str(row[i]) for i in pointers]))
66-
f_export.close()
69+
with open(fmt_export, "w") as f_export:
70+
f_export.write(",".join(columns) + "\n")
71+
for row in data.select(columns).iter_rows():
72+
f_export.write(",".join(map(str, row)) + "\n")
6773

68-
else:
74+
elif fmt == "json":
75+
# export to JSON
6976
fmt_export = os.path.join(cls.export_dir, cls._gen_basename() + ".json")
70-
f_export = open(fmt_export, "w")
71-
72-
if plottype == 'bar':
73-
74-
plot["payload"] = {"x": [], "y": [], "xtitle": cls._get_title(columns[0]), "ytitle": cls._get_title(columns[1])}
75-
76-
for _, row in iter_data():
77-
plot["payload"]["x"].append(row[pointers[0]])
78-
plot["payload"]["y"].append(row[pointers[1]])
79-
80-
elif plottype == 'plot3d':
77+
with open(fmt_export, "w") as f_export:
78+
if plottype == "bar":
79+
# bar plot payload
80+
plot["payload"] = {
81+
"x": [data[columns[0]].to_list()],
82+
"y": data[columns[1]].to_list(),
83+
"xtitle": cls._get_title(columns[0]),
84+
"ytitle": cls._get_title(columns[1]),
85+
}
86+
87+
elif plottype == "plot3d":
88+
# 3D plot payload
89+
plot["payload"] = {
90+
"points": {"x": [], "y": [], "z": [], "labels": []},
91+
"meshes": [],
92+
"xtitle": cls._get_title(columns[0]),
93+
"ytitle": cls._get_title(columns[1]),
94+
"ztitle": cls._get_title(columns[2]),
95+
}
96+
recent_mesh = None
97+
for row in data.iter_rows():
98+
plot["payload"]["points"]["x"].append(
99+
row[data.columns.index(columns[0])]
100+
)
101+
plot["payload"]["points"]["y"].append(
102+
row[data.columns.index(columns[1])]
103+
)
104+
plot["payload"]["points"]["z"].append(
105+
row[data.columns.index(columns[2])]
106+
)
107+
plot["payload"]["points"]["labels"].append(
108+
row[data.columns.index(columns[3])]
109+
)
110+
111+
if row[data.columns.index(columns[4])] != recent_mesh:
112+
plot["payload"]["meshes"].append(
113+
{"x": [], "y": [], "z": []}
114+
)
115+
recent_mesh = row[data.columns.index(columns[4])]
116+
117+
if plot["payload"]["meshes"]:
118+
plot["payload"]["meshes"][-1]["x"].append(
119+
row[data.columns.index(columns[0])]
120+
)
121+
plot["payload"]["meshes"][-1]["y"].append(
122+
row[data.columns.index(columns[1])]
123+
)
124+
plot["payload"]["meshes"][-1]["z"].append(
125+
row[data.columns.index(columns[2])]
126+
)
127+
else:
128+
raise RuntimeError(f"Error: {plottype} is an unknown plot type")
129+
130+
if kwargs:
131+
plot["payload"].update(kwargs)
132+
133+
# write JSON to file
134+
f_export.write(json.dumps(plot, escape_forward_slashes=False, indent=4))
81135

82-
plot["payload"]["points"] = {"x": [], "y": [], "z": [], "labels": []}
83-
plot["payload"]["meshes"] = []
84-
plot["payload"]["xtitle"] = cls._get_title(columns[0])
85-
plot["payload"]["ytitle"] = cls._get_title(columns[1])
86-
plot["payload"]["ztitle"] = cls._get_title(columns[2])
87-
recent_mesh = 0
88-
89-
for _, row in iter_data():
90-
plot["payload"]["points"]["x"].append(row[pointers[0]])
91-
plot["payload"]["points"]["y"].append(row[pointers[1]])
92-
plot["payload"]["points"]["z"].append(row[pointers[2]])
93-
plot["payload"]["points"]["labels"].append(row[pointers[3]])
94-
95-
if row[4] != recent_mesh:
96-
plot["payload"]["meshes"].append({"x": [], "y": [], "z": []})
97-
recent_mesh = row[4]
98-
99-
if plot["payload"]["meshes"]:
100-
plot["payload"]["meshes"][-1]["x"].append(row[pointers[0]])
101-
plot["payload"]["meshes"][-1]["y"].append(row[pointers[1]])
102-
plot["payload"]["meshes"][-1]["z"].append(row[pointers[2]])
103-
104-
if kwargs:
105-
plot["payload"].update(kwargs)
106-
107-
else: raise RuntimeError("\r\nError: %s is an unknown plot type" % plottype)
108-
109-
f_export.write(json.dumps(plot, escape_forward_slashes=False, indent=4))
110-
f_export.close()
136+
else:
137+
raise ValueError(f"Unsupported format: {fmt}")
111138

112139
return fmt_export
113140

114141
@classmethod
115142
def save_df(cls, frame, tag):
116143
cls._verify_export_dir()
144+
if not isinstance(frame, pl.DataFrame):
145+
raise TypeError("Input frame must be a Polars DataFrame")
146+
117147
if tag is None:
118-
tag = '-'
148+
tag = "-"
119149

120-
pkl_export = os.path.join(cls.export_dir, 'df' + str(tag) + '_' + cls._gen_basename() + ".pkl")
121-
frame.to_pickle(pkl_export, protocol=2) # Py2-3 compat
150+
pkl_export = os.path.join(
151+
cls.export_dir, "df" + str(tag) + "_" + cls._gen_basename() + ".pkl"
152+
)
153+
frame.write_parquet(pkl_export) # because pickle is not supported in polars
122154
return pkl_export
123155

124156
@classmethod
125157
def save_model(cls, skmodel, tag):
126-
127158
import _pickle as cPickle
128159

129160
cls._verify_export_dir()
130161
if tag is None:
131-
tag = '-'
162+
tag = "-"
132163

133-
pkl_export = os.path.join(cls.export_dir, 'ml' + str(tag) + '_' + cls._gen_basename() + ".pkl")
134-
with open(pkl_export, 'wb') as f:
164+
pkl_export = os.path.join(
165+
cls.export_dir, "ml" + str(tag) + "_" + cls._gen_basename() + ".pkl"
166+
)
167+
with open(pkl_export, "wb") as f:
135168
cPickle.dump(skmodel, f)
136169
return pkl_export

0 commit comments

Comments
 (0)