diff --git a/exampleScripts/deduplicate.py b/exampleScripts/deduplicate.py new file mode 100644 index 0000000..3c9eba1 --- /dev/null +++ b/exampleScripts/deduplicate.py @@ -0,0 +1,125 @@ +import numpy as np +import pandas as pd +import argparse +from multiprocessing import Pool +from astropy.coordinates import SkyCoord +parser = argparse.ArgumentParser( + prog = 'deduplicate', + description = 'Given a csv file containing columns with at least wm_ra, wm_decl, datapoints, rb_smaj, runcat reduce source density by eliminating sources within 5 beamwidths, keeping the source with the highest number of datapoints') +parser.add_argument('--csvfile',type=str) +parser.add_argument('--pullsql',action='store_true', help='option to pull data from the database, assuming trapenvvar.bash is already sourced.') +parser.add_argument('--dataset',type=int,help='dataset to pull from sql') +parser.add_argument('--getoutput',action='store_true',help='Also dump csv file for input into slidingetawindow.py') +parser.add_argument('--minrms',type=float, default=1e-16, help='Exclude measurements below this rmsqc value') +parser.add_argument('--maxrms',type=float, default=1e16,help='Exclude measurements above this rmsqc value') +args = parser.parse_args() +def labelrgn(myid): +# print(myid) + myra = float(dat['wm_ra'][dat['runcat']==myid].unique()) + mydec = float(dat['wm_decl'][dat['runcat']==myid].unique()) + mypts = int(dat['datapoints'][dat['runcat']==myid].unique()) + mymaj = np.amax(dat['rb_smaj'][dat['runcat']==myid]) + mysc = SkyCoord(myra,mydec,unit='deg') + otherid = dat['runcat'][dat['runcat']!=myid].unique() + otherra = dat['wm_ra'][dat['runcat']!=myid].unique() + otherdec = dat['wm_decl'][dat['runcat']!=myid].unique() + otheridpts = np.unique(dat[['runcat', 'datapoints']].to_numpy(),axis=0) + alldist = mysc.separation(SkyCoord(otherra,otherdec,unit='deg')).degree + samergn = alldist < (5*mymaj) + if np.sum(samergn) > 0: + outarr = np.zeros(np.sum(samergn)+1,dtype=[('id','i8'),('rgnid','i8')]) + # print(np.sum(samergn),samergn.shape,otherid.shape,) + # print(len(outarr)) + outarr['id'][0:-1] = otherid[samergn] + outarr['id'][-1] = myid + outarr['rgnid'] = np.amin(outarr['id']) + else: + outarr = np.zeros(1,dtype=[('id','i8'),('rgnid','i8'),('pts','i8')]) + outarr['id'] = myid + outarr['rgnid']=myid + return [(o['id'],o['rgnid']) for o in outarr] + +def initdb(): + import psycopg2 + import os + from psycopg2.extensions import register_adapter, AsIs + # psycopg2 complains about numpy datatypes, this avoids that error. Got this from stackoverflow but don't recall where + def addapt_numpy_float64(numpy_float64): + return AsIs(numpy_float64) + def addapt_numpy_int64(numpy_int64): + return AsIs(numpy_int64) + register_adapter(np.float64, addapt_numpy_float64) + register_adapter(np.int64, addapt_numpy_int64) + host = os.getenv('TKP_DBHOST') + port = os.getenv('TKP_DBPORT') + user = os.getenv('TKP_DBUSER') + password = os.getenv('TKP_DBPASSWORD') + database = os.getenv('TKP_DBNAME') + if password is None: + conn = psycopg2.connect("dbname="+database+" user="+user+" host="+host) + else: + conn = psycopg2.connect("dbname="+database+" user="+user+" password="+password+" host="+host) + + ################################################################################ + # Typical psycopg2 setup + conn = psycopg2.connect("dbname="+database+" user="+user+" password="+password+" host="+host) + return conn.cursor() + +if __name__=='__main__': + if args.pullsql: + cur = initdb() + + + query = """SELECT image.rms_qc, image.rb_smaj,runningcatalog.wm_ra, runningcatalog.wm_decl,runningcatalog.datapoints, image.taustart_ts, image.band, image.freq_eff, extractedsource.f_peak, extractedsource.f_peak_err, extractedsource.f_int, extractedsource.f_int_err, assocxtrsource.runcat, extractedsource.det_sigma, extractedsource.ra, extractedsource.decl, sqlseparation(extractedsource.ra,extractedsource.decl, skyregion.centre_ra, skyregion.centre_decl), image.dataset FROM extractedsource JOIN assocxtrsource ON assocxtrsource.xtrsrc=extractedsource.id JOIN image ON extractedsource.image=image.id JOIN skyregion ON image.skyrgn=skyregion.id JOIN runningcatalog ON assocxtrsource.runcat=runningcatalog.id WHERE image.dataset=%s AND image.rms_qc<%s AND image.rms_qc>%s GROUP BY image.rb_smaj,image.taustart_ts, image.band, image.freq_eff, extractedsource.f_peak, extractedsource.f_peak_err, extractedsource.f_int, extractedsource.f_int_err, assocxtrsource.runcat, extractedsource.det_sigma, extractedsource.ra, extractedsource.decl, skyregion.centre_ra, skyregion.centre_decl, image.dataset, runningcatalog.wm_ra, runningcatalog.wm_decl, runningcatalog.datapoints, image.rms_qc HAVING sqlseparation(extractedsource.ra, extractedsource.decl, skyregion.centre_ra, skyregion.centre_decl) < 0.8""" + print('Querying the database') + cur.execute(query,(args.dataset,args.maxrms, args.minrms)) + + dat = pd.DataFrame(cur.fetchall(), columns=['rms_qc','rb_smaj','wm_ra','wm_decl','datapoints','taustart_ts', + 'band','freq_eff','f_peak','f_peak_err','f_int','f_int_err', + 'runcat','det_sigma','ra','decl','dist','dataset']) + print('Put data into dataframe!') + print(dat['rms_qc'].min(), dat['rms_qc'].max()) + elif args.csvfile: + dat = pd.read_csv(args.csvfile) + + + else: + print("Please specify data source") + import sys + sys.exit(1) + + uniqueids = dat['runcat'].unique() + with Pool() as pool: + results = pool.map(labelrgn, uniqueids) + + flatresults = [] + for r in results: + flatresults.extend(r) + locarray = np.array(flatresults) + keepids = [] + for uniqueloc in np.unique(locarray): + idinlocation = uniqueids[np.isin(uniqueids,locarray[locarray[:,1]==uniqueloc,0])] + subdat = dat[np.isin(dat['runcat'],idinlocation)] + keepdat = subdat[subdat['datapoints']==subdat['datapoints'].max()] + for myid in keepdat['runcat'].unique(): + keepids.append(myid) + print(keepids) + with open('dumpdedupedids.txt','w') as f: + f.write(str(tuple(keepids))) + + if args.getoutput: + import os + database = os.getenv('TKP_DBNAME') + if not args.pullsql: + cur = initdb() + query = """SELECT image.taustart_ts, image.band, image.freq_eff, extractedsource.f_peak, extractedsource.f_peak_err, extractedsource.f_int, extractedsource.f_int_err, assocxtrsource.runcat, extractedsource.det_sigma, extractedsource.ra, extractedsource.decl, sqlseparation(extractedsource.ra,extractedsource.decl, skyregion.centre_ra, skyregion.centre_decl), image.dataset FROM extractedsource JOIN assocxtrsource ON assocxtrsource.xtrsrc=extractedsource.id JOIN image ON extractedsource.image=image.id JOIN skyregion ON image.skyrgn=skyregion.id WHERE assocxtrsource.runcat IN %s AND image.rms_qc>%s AND image.rms_qc<%s ORDER BY assocxtrsource.runcat ASC ;""" + cur.execute(query, (tuple(keepids),args.minrms, args.maxrms,)) + outdf = pd.DataFrame(cur.fetchall(),columns=['taustart_ts','band','freq_eff','f_peak','f_peak_err', + 'f_int','f_int_err','runcat','det_sigma','ra','decl','dist','dataset']) + with open(f'{database}deduped.csv','w') as f: + f.write(outdf.to_csv(index=False)) + + + + + diff --git a/exampleScripts/fastanimate.py b/exampleScripts/fastanimate.py new file mode 100644 index 0000000..7be2b9f --- /dev/null +++ b/exampleScripts/fastanimate.py @@ -0,0 +1,161 @@ +import datetime +import imageio +import matplotlib.pyplot as plt +import os +import numpy as np +import argparse +from astropy.io import fits +from astropy.coordinates import SkyCoord +from astropy.wcs import WCS +from astropy.visualization.wcsaxes import Quadrangle +from astropy import units as u +import psycopg2 +from psycopg2.extensions import register_adapter, AsIs +from multiprocessing import Pool + +# psycopg2 complains about numpy datatypes, this avoids that error. Got this from stackoverflow but don't recall where +def addapt_numpy_float64(numpy_float64): + return AsIs(numpy_float64) +def addapt_numpy_int64(numpy_int64): + return AsIs(numpy_int64) +register_adapter(np.float64, addapt_numpy_float64) +register_adapter(np.int64, addapt_numpy_int64) + +def makepreimages(imnum,frame): + fig = plt.figure(figsize=(25,10)) + axlist = [] + with fits.open(frame['image']) as hdul: + src = srcdat[0] + hdu = hdul[0] + imdat = hdu.data[0:][0:][0][0] + wcs = WCS(hdu.header, naxis=2) + sc = SkyCoord(src['ra'],src['dec'],unit='deg') + wcslocation = wcs.world_to_pixel(sc) + axlist.append(fig.add_subplot(1,3,1,projection= wcs)) + axlist[0].imshow(imdat, vmin=-2e-5, vmax=np.amax(srcdat['fint']), origin='lower') + axlist[0].set_xlim(wcslocation[0] - 50 , wcslocation[0] + 50) + axlist[0].set_ylim(wcslocation[1] - 50, wcslocation[1] + 50 ) + axlist.append(fig.add_subplot(1,3,2)) + curdates = srcdat['date'][np.abs(srcdat['date'] - src['date']) < 24] + curfints = srcdat['fint'][np.abs(srcdat['date'] - src['date']) < 24] + curfinterrs = srcdat['finterr'][np.abs(srcdat['date'] - src['date']) < 24] + axlist[1].scatter(curdates,curfints) + axlist[1].set_box_aspect(imdat.shape[0]/imdat.shape[1]) + axlist[1].errorbar(curdates,curfints,yerr=curfinterrs, fmt='none') + axlist[1].set_ylim(1e-6,curfints.max()) + axlist[1].set_xlim(curdates.min(),curdates.max()) + axlist[1].set_xlabel('MJD') + axlist[1].set_ylabel('$F_{int}$') + axlist[1].axvline(src['date']) + axlist.append(fig.add_subplot(1,3,3,projection= wcs)) + axlist[2].imshow(imdat, vmin=-2e-5, vmax=np.amax(srcdat['fint']), origin='lower') + axlist[2].add_patch(Quadrangle((sc.ra, sc.dec), 100*u.arcsec, 100*u.arcsec, + edgecolor='white', facecolor='none',transform=axlist[2].get_transform('fk5'))) + + plt.savefig(f'tmpim{imnum}.png') + imnum += 1 + plt.close() +def plotsrcs(imnum,src,srcdat): + fig = plt.figure(figsize=(25,10)) + axlist = [] + with fits.open(src['image']) as hdul: + hdu = hdul[0] + imdat = hdu.data[0:][0:][0][0] + wcs = WCS(hdu.header, naxis=2) + sc = SkyCoord(src['ra'],src['dec'],unit='deg') + sc2 = SkyCoord(srcdat['ra'][0],srcdat['dec'][0],unit='deg') + wcslocation = wcs.world_to_pixel(sc) + wcslocation2 = wcs.world_to_pixel(sc2) + axlist.append(fig.add_subplot(1,3,1,projection= wcs)) + axlist[0].imshow(imdat, vmin=-2e-5, vmax=np.amax(srcdat['fint']), origin='lower') + axlist[0].scatter(wcslocation[0],wcslocation[1],marker='x',color='red') + axlist[0].set_xlim(wcslocation2[0] - 50 , wcslocation2[0] + 50) + axlist[0].set_ylim(wcslocation2[1] - 50, wcslocation2[1] + 50 ) + axlist.append(fig.add_subplot(1,3,2)) + curdates = srcdat['date'][np.abs(srcdat['date'] - src['date']) < 24] + curfints = srcdat['fint'][np.abs(srcdat['date'] - src['date']) < 24] + curfinterrs = srcdat['finterr'][np.abs(srcdat['date'] - src['date']) < 24] + axlist[1].scatter(curdates,curfints) + axlist[1].errorbar(curdates,curfints,yerr=curfinterrs, fmt='none') + axlist[1].set_ylim(1e-6,curfints.max()) + axlist[1].set_xlim(curdates.min(),curdates.max()) + axlist[1].set_xlabel('MJD') + axlist[1].set_ylabel('$F_{int}$') + axlist[1].set_box_aspect(imdat.shape[0]/imdat.shape[1]) + axlist[1].axvline(src['date']) + axlist.append(fig.add_subplot(1,3,3,projection= wcs)) + axlist[2].imshow(imdat, vmin=-2e-5, vmax=np.amax(srcdat['fint']), origin='lower') + axlist[2].add_patch(Quadrangle((sc.ra, sc.dec), 100*u.arcsec, 100*u.arcsec, + edgecolor='white', facecolor='none',transform=axlist[2].get_transform('fk5'))) + plt.savefig(f'tmpim{imnum}.png') + plt.close() +# In order to use MJD +start_epoch = datetime.datetime(1858, 11, 17, 00, 00, 00, 00) + +parser = argparse.ArgumentParser() +parser.add_argument('srcid',nargs='*',help='sources to animate') +args = parser.parse_args() +if __name__ == '__main__': + + host = os.getenv('TKP_DBHOST') + port = os.getenv('TKP_DBPORT') + user = os.getenv('TKP_DBUSER') + password = os.getenv('TKP_DBPASSWORD') + database = os.getenv('TKP_DBNAME') + if password is None: + conn = psycopg2.connect("dbname="+database+" user="+user+" host="+host) + else: + conn = psycopg2.connect("dbname="+database+" user="+user+" password="+password+" host="+host) + + ################################################################################ + # Typical psycopg2 setup + conn = psycopg2.connect("dbname="+database+" user="+user+" password="+password+" host="+host) + cur = conn.cursor() + for myid in args.srcid: + + + query = """SELECT extractedsource.f_int, extractedsource.f_int_err,extractedsource.ra, extractedsource.decl, image.url, image.taustart_ts FROM extractedsource JOIN assocxtrsource ON extractedsource.id=assocxtrsource.xtrsrc JOIN image ON extractedsource.image=image.id WHERE assocxtrsource.runcat=%s;""" + cur.execute(query,(myid,)) + + + convdate = [(fint,finterr,ra,dec,image,(d-start_epoch).total_seconds()/3600/24) for fint,finterr,ra,dec,image,d in cur.fetchall()] + srcdat =np.sort(np.array(convdate, dtype=[('fint','f8'),('finterr','f8'),('ra','f8'),('dec','f8'),('image',' startindex: + with Pool() as pool: + pool.starmap(makepreimages, enumerate(alldat[startindex:endindex])) + + + preimnum = endindex - startindex + endimnum = len(srcdat) + preimnum + with Pool() as pool: + pool.starmap(plotsrcs,[(preimnum+ind,src,srcdat) for ind,src in enumerate(srcdat)]) + imnum = endimnum + with imageio.get_writer(f'src{myid}.mp4', mode='I',fps=3) as writer: + for ind in range(imnum): + image = imageio.imread(f'tmpim{ind}.png') + writer.append_data(image) + print(f"wrote video: src{myid}.mp4") + for myfile in [f'tmpim{ind}.png' for ind in range(imnum)]: + print('removing: ',myfile) + os.remove(myfile) + + + + + + + + +