22Utilities for convenient
33exporting the MPDS data
44"""
5+
56import os
67import random
7-
88import ujson as json
9- import pandas as pd
9+ import polars as pl
10+ from typing import Union
1011
1112
1213class MPDSExport (object ):
13-
1414 export_dir = "/tmp/_MPDS"
1515
1616 human_names = {
17- ' length' : ' Bond lengths, A' ,
18- ' occurrence' : ' Counts' ,
19- ' bandgap' : ' Band gap, eV'
17+ " length" : " Bond lengths, A" ,
18+ " occurrence" : " Counts" ,
19+ " bandgap" : " Band gap, eV" ,
2020 }
2121
2222 @classmethod
@@ -32,15 +32,21 @@ def _gen_basename(cls):
3232 basename = []
3333 random .seed ()
3434 for _ in range (12 ):
35- basename .append (random .choice ("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" ))
35+ basename .append (
36+ random .choice (
37+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
38+ )
39+ )
3640 return "" .join (basename )
3741
3842 @classmethod
39- def _get_title (cls , term ):
43+ def _get_title (cls , term : Union [str , int ]):
44+ if isinstance (term , int ):
45+ return str (term )
4046 return cls .human_names .get (term , term .capitalize ())
4147
4248 @classmethod
43- def save_plot (cls , data , columns , plottype , fmt = ' json' , ** kwargs ):
49+ def save_plot (cls , data , columns , plottype , fmt = " json" , ** kwargs ):
4450 """
4551 Exports the data in the following formats for plotting:
4652
@@ -50,87 +56,114 @@ def save_plot(cls, data, columns, plottype, fmt='json', **kwargs):
5056 cls ._verify_export_dir ()
5157 plot = {"use_visavis_type" : plottype , "payload" : {}}
5258
53- if isinstance (data , pd .DataFrame ):
54- iter_data = data . iterrows
55- pointers = columns
56- else :
57- iter_data = lambda : enumerate ( data )
58- pointers = range ( len ( data [ 0 ]) )
59+ if not isinstance (data , pl .DataFrame ):
60+ raise TypeError ( "The ' data' parameter must be a Polars DataFrame" )
61+
62+ # сheck that columns are valid
63+ if not all ( col in data . columns for col in columns ):
64+ raise ValueError ( "Some specified columns are not in the DataFrame" )
5965
60- if fmt == 'csv' :
66+ if fmt == "csv" :
67+ # export to CSV
6168 fmt_export = os .path .join (cls .export_dir , cls ._gen_basename () + ".csv" )
62- f_export = open (fmt_export , "w" )
63- f_export .write ("%s\n " % "," .join (map (str , columns )))
64- for _ , row in iter_data ():
65- f_export .write ("%s\n " % "," .join ([str (row [i ]) for i in pointers ]))
66- f_export .close ()
69+ with open (fmt_export , "w" ) as f_export :
70+ f_export .write ("," .join (columns ) + "\n " )
71+ for row in data .select (columns ).iter_rows ():
72+ f_export .write ("," .join (map (str , row )) + "\n " )
6773
68- else :
74+ elif fmt == "json" :
75+ # export to JSON
6976 fmt_export = os .path .join (cls .export_dir , cls ._gen_basename () + ".json" )
70- f_export = open (fmt_export , "w" )
71-
72- if plottype == 'bar' :
73-
74- plot ["payload" ] = {"x" : [], "y" : [], "xtitle" : cls ._get_title (columns [0 ]), "ytitle" : cls ._get_title (columns [1 ])}
75-
76- for _ , row in iter_data ():
77- plot ["payload" ]["x" ].append (row [pointers [0 ]])
78- plot ["payload" ]["y" ].append (row [pointers [1 ]])
79-
80- elif plottype == 'plot3d' :
77+ with open (fmt_export , "w" ) as f_export :
78+ if plottype == "bar" :
79+ # bar plot payload
80+ plot ["payload" ] = {
81+ "x" : [data [columns [0 ]].to_list ()],
82+ "y" : data [columns [1 ]].to_list (),
83+ "xtitle" : cls ._get_title (columns [0 ]),
84+ "ytitle" : cls ._get_title (columns [1 ]),
85+ }
86+
87+ elif plottype == "plot3d" :
88+ # 3D plot payload
89+ plot ["payload" ] = {
90+ "points" : {"x" : [], "y" : [], "z" : [], "labels" : []},
91+ "meshes" : [],
92+ "xtitle" : cls ._get_title (columns [0 ]),
93+ "ytitle" : cls ._get_title (columns [1 ]),
94+ "ztitle" : cls ._get_title (columns [2 ]),
95+ }
96+ recent_mesh = None
97+ for row in data .iter_rows ():
98+ plot ["payload" ]["points" ]["x" ].append (
99+ row [data .columns .index (columns [0 ])]
100+ )
101+ plot ["payload" ]["points" ]["y" ].append (
102+ row [data .columns .index (columns [1 ])]
103+ )
104+ plot ["payload" ]["points" ]["z" ].append (
105+ row [data .columns .index (columns [2 ])]
106+ )
107+ plot ["payload" ]["points" ]["labels" ].append (
108+ row [data .columns .index (columns [3 ])]
109+ )
110+
111+ if row [data .columns .index (columns [4 ])] != recent_mesh :
112+ plot ["payload" ]["meshes" ].append (
113+ {"x" : [], "y" : [], "z" : []}
114+ )
115+ recent_mesh = row [data .columns .index (columns [4 ])]
116+
117+ if plot ["payload" ]["meshes" ]:
118+ plot ["payload" ]["meshes" ][- 1 ]["x" ].append (
119+ row [data .columns .index (columns [0 ])]
120+ )
121+ plot ["payload" ]["meshes" ][- 1 ]["y" ].append (
122+ row [data .columns .index (columns [1 ])]
123+ )
124+ plot ["payload" ]["meshes" ][- 1 ]["z" ].append (
125+ row [data .columns .index (columns [2 ])]
126+ )
127+ else :
128+ raise RuntimeError (f"Error: { plottype } is an unknown plot type" )
129+
130+ if kwargs :
131+ plot ["payload" ].update (kwargs )
132+
133+ # write JSON to file
134+ f_export .write (json .dumps (plot , escape_forward_slashes = False , indent = 4 ))
81135
82- plot ["payload" ]["points" ] = {"x" : [], "y" : [], "z" : [], "labels" : []}
83- plot ["payload" ]["meshes" ] = []
84- plot ["payload" ]["xtitle" ] = cls ._get_title (columns [0 ])
85- plot ["payload" ]["ytitle" ] = cls ._get_title (columns [1 ])
86- plot ["payload" ]["ztitle" ] = cls ._get_title (columns [2 ])
87- recent_mesh = 0
88-
89- for _ , row in iter_data ():
90- plot ["payload" ]["points" ]["x" ].append (row [pointers [0 ]])
91- plot ["payload" ]["points" ]["y" ].append (row [pointers [1 ]])
92- plot ["payload" ]["points" ]["z" ].append (row [pointers [2 ]])
93- plot ["payload" ]["points" ]["labels" ].append (row [pointers [3 ]])
94-
95- if row [4 ] != recent_mesh :
96- plot ["payload" ]["meshes" ].append ({"x" : [], "y" : [], "z" : []})
97- recent_mesh = row [4 ]
98-
99- if plot ["payload" ]["meshes" ]:
100- plot ["payload" ]["meshes" ][- 1 ]["x" ].append (row [pointers [0 ]])
101- plot ["payload" ]["meshes" ][- 1 ]["y" ].append (row [pointers [1 ]])
102- plot ["payload" ]["meshes" ][- 1 ]["z" ].append (row [pointers [2 ]])
103-
104- if kwargs :
105- plot ["payload" ].update (kwargs )
106-
107- else : raise RuntimeError ("\r \n Error: %s is an unknown plot type" % plottype )
108-
109- f_export .write (json .dumps (plot , escape_forward_slashes = False , indent = 4 ))
110- f_export .close ()
136+ else :
137+ raise ValueError (f"Unsupported format: { fmt } " )
111138
112139 return fmt_export
113140
114141 @classmethod
115142 def save_df (cls , frame , tag ):
116143 cls ._verify_export_dir ()
144+ if not isinstance (frame , pl .DataFrame ):
145+ raise TypeError ("Input frame must be a Polars DataFrame" )
146+
117147 if tag is None :
118- tag = '-'
148+ tag = "-"
119149
120- pkl_export = os .path .join (cls .export_dir , 'df' + str (tag ) + '_' + cls ._gen_basename () + ".pkl" )
121- frame .to_pickle (pkl_export , protocol = 2 ) # Py2-3 compat
150+ pkl_export = os .path .join (
151+ cls .export_dir , "df" + str (tag ) + "_" + cls ._gen_basename () + ".pkl"
152+ )
153+ frame .write_parquet (pkl_export ) # cos pickle is not supported in polars
122154 return pkl_export
123155
124156 @classmethod
125157 def save_model (cls , skmodel , tag ):
126-
127158 import _pickle as cPickle
128159
129160 cls ._verify_export_dir ()
130161 if tag is None :
131- tag = '-'
162+ tag = "-"
132163
133- pkl_export = os .path .join (cls .export_dir , 'ml' + str (tag ) + '_' + cls ._gen_basename () + ".pkl" )
134- with open (pkl_export , 'wb' ) as f :
164+ pkl_export = os .path .join (
165+ cls .export_dir , "ml" + str (tag ) + "_" + cls ._gen_basename () + ".pkl"
166+ )
167+ with open (pkl_export , "wb" ) as f :
135168 cPickle .dump (skmodel , f )
136169 return pkl_export
0 commit comments