Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions src/vetting/centroiding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,51 @@ def weighted_average(values, weights, axis=None):
return mean, error / len(values) ** 0.5


def plot_centroids_vs_time(lc, centroid_type, label=None):
# tmasks: list of transit mask, one for each planet candidate
tmasks = np.asarray([lc[f"tmask{i}"] for i in range(lc.meta["num_tmasks"])])
oot_mask = tmasks.all(axis=0)
with plt.style.context(lk.MPLSTYLE):
fig, axs = plt.subplots(
2,
1,
figsize=(8, 4 * 2),
sharex=True,
)

if centroid_type == "cent_diff":
col_plot_label_suffix = "centroid column diff"
row_plot_label_suffix = "centroid row diff"
elif centroid_type == "cent":
col_plot_label_suffix = "centroid column"
row_plot_label_suffix = "centroid row"
elif centroid_type == "tr":
col_plot_label_suffix = "centroid detrended"
row_plot_label_suffix = "centroid detrended"
else:
raise ValueError(f"Unsupported centroid type: {centroid_type}")

ax = lc[oot_mask].scatter(ax=axs[0], column=f"x{centroid_type}", alpha=0.3, label=f"OOT {col_plot_label_suffix}")
for idx in range(len(tmasks)):
# Transits of planet IDX
it_mask = ~tmasks[idx]
lc[it_mask].scatter(ax=ax, column=f"x{centroid_type}", s=16, marker="*", label=f"Pl {idx + 1} {col_plot_label_suffix}")

ax = lc[oot_mask].scatter(ax=axs[1], column=f"y{centroid_type}", alpha=0.3, label=f"OOT {row_plot_label_suffix}")
for idx in range(len(tmasks)):
# Transits of planet IDX
it_mask = ~tmasks[idx]
lc[it_mask].scatter(ax=ax, column=f"y{centroid_type}", s=16, marker="*", label=f"Pl {idx + 1} {row_plot_label_suffix}")

plt.subplots_adjust(hspace=0)

if label is None:
label = lc.meta.get("LABEL")
if label is not None:
fig.suptitle(label)
return fig


def centroid_test(
tpfs,
periods,
Expand All @@ -35,6 +80,7 @@ def centroid_test(
kernel=21,
aperture_mask="pipeline",
plot=True,
include_diagnostics=False,
nsamp=100,
transit_depths=None,
labels=None,
Expand Down Expand Up @@ -117,11 +163,14 @@ def centroid_test(

nplanets = len(periods)
r = {}
for key in ["figs", "pvalues", "centroid_offset_detected"]:
for key in ["figs", "pvalues", "pvalues_x", "pvalues_y", "centroid_offset_detected"]:
r[key] = []
if transit_depths is not None:
for key in ["1sigma_error"]:
r[key] = []
if include_diagnostics:
for key in ["lc_list", "tpf_m_list", "diagnostics_figs"]:
r[key] = []

for tpf in tpfs:
if tpf.mission.lower() in ["kepler", "ktwo", "k2"]:
Expand All @@ -146,6 +195,9 @@ def centroid_test(
mask &= np.nan_to_num(tpf.to_lightcurve(aperture_mask=aper).flux) != 0
tpf = tpf[mask]
lc = tpf.to_lightcurve(aperture_mask=aper)
if include_diagnostics:
r["tpf_m_list"].append(tpf)
r["lc_list"].append(lc)

tmasks = []
for period, t0, duration in zip(periods, t0s, durs):
Expand All @@ -155,6 +207,14 @@ def centroid_test(
)
tmasks.append(t_mask)
tmasks = np.asarray(tmasks)
if include_diagnostics:
# put tmasks as a list of extra columns, one for each planet
# this is done (instead of, say, putting the entire `tmasks`` in meta)
# to make it easy for users to truncate the lc and plot only a section of the time range
# (the tmasks would be truncated accordingly, by the virtue that they are basic columns)
for i in range(len(tmasks)):
lc[f"tmask{i}"] = tmasks[i]
lc.meta["num_tmasks"] = len(tmasks)
Y, X = np.mgrid[: tpf.shape[1], : tpf.shape[2]]
X = (X[aper][:, None] * np.ones(tpf.shape[0])).T
Y = (Y[aper][:, None] * np.ones(tpf.shape[0])).T
Expand Down Expand Up @@ -236,6 +296,22 @@ def centroid_test(
ycent[:, 0] - ytr, ycent[:, 1], size=(nsamp, len(ycent))
)

if include_diagnostics:
lc["xcent"] = xcent[:, 0]
lc["xcent_err"] = xcent[:, 1]
lc["ycent"] = ycent[:, 0]
lc["ycent_err"] = ycent[:, 1]
lc["xtr"] = xtr
lc["ytr"] = ytr
lc["xcent_diff"] = lc["xcent"] - lc["xtr"]
lc["ycent_diff"] = lc["ycent"] - lc["ytr"]
lc.meta["xsamps"] = xsamps
lc.meta["ysamps"] = ysamps

if include_diagnostics and plot:
diagnostics_fig = plot_centroids_vs_time(lc, "cent_diff", label=_label(tpf))
r["diagnostics_figs"].append(diagnostics_fig)

if plot:
with plt.style.context("seaborn-white"):
fig, axs = plt.subplots(
Expand All @@ -249,7 +325,7 @@ def centroid_test(
if not hasattr(axs, "__iter__"):
axs = [axs]

pvalues, sigma1, centroid_offset_detected = [], [], []
pvalues, pvalues_x, pvalues_y, sigma1, centroid_offset_detected = [], [], [], [], []
for idx in range(nplanets):
# NO Transits
k1 = (tmasks).all(axis=0)
Expand Down Expand Up @@ -303,12 +379,14 @@ def centroid_test(
)
axs[idx].legend(loc="upper left")

ps = []
ps, ps_x, ps_y = [], [], []
# ps1 = []
for x1, y1 in zip(xsamps, ysamps):
px = ttest_ind(x1[k1], x1[k2], equal_var=False)
py = ttest_ind(y1[k1], y1[k2], equal_var=False)
ps.append(np.mean([px.pvalue, py.pvalue]))
ps_x.append(px.pvalue)
ps_y.append(py.pvalue)

if transit_depths is not None:
# Weighted average and weighted standard deviation of out of transit
Expand All @@ -320,18 +398,18 @@ def centroid_test(
pos_err = np.hypot(np.hypot(a1[1], b1[1]), np.hypot(a2[1], b2[1]))
sigma1.append(pixel_scale * pos_err / transit_depths[idx])
if k2.sum() == 0:
pvalue = 1
pvalue, pvalue_x, pvalue_y = 1, 1, 1
else:
pvalue = np.mean(ps)
pvalue, pvalue_x, pvalue_y = np.mean(ps), np.mean(ps_x), np.mean(ps_y)
pvalues.append(pvalue)
if pvalue < 0.05:
centroid_offset_detected.append(True)
else:
centroid_offset_detected.append(False)
pvalues_x.append(pvalue_x)
pvalues_y.append(pvalue_y)
p_centroid_offset_detected = np.min([pvalue_x, pvalue_y]) < 0.05
centroid_offset_detected.append(p_centroid_offset_detected)

if plot:
with plt.style.context("seaborn-white"):
if pvalue >= 0.05:
if not p_centroid_offset_detected:
label = f"No Significant Offset (p-value: {pvalue:.2E})"
if transit_depths is not None:
label = (
Expand Down Expand Up @@ -369,5 +447,7 @@ def centroid_test(
if plot:
r["figs"].append(fig)
r["pvalues"].append(tuple(pvalues))
r["pvalues_x"].append(tuple(pvalues_x))
r["pvalues_y"].append(tuple(pvalues_y))

return r