Skip to content

Commit

Permalink
Improve sRF fit plot.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanoesterle committed Jul 2, 2024
1 parent 7c4d728 commit 6b03adc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
23 changes: 15 additions & 8 deletions djimaging/tables/receptivefield/rf_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datajoint as dj
import numpy as np
from djimaging.utils import math_utils
from djimaging.utils.receptive_fields.plot_rf_utils import plot_srf_gauss_fit

from djimaging.utils.trace_utils import sort_traces
from matplotlib import pyplot as plt
Expand All @@ -14,7 +15,7 @@
from djimaging.tables.receptivefield.rf_utils import compute_explained_rf, resize_srf, split_strf, \
compute_polarity_and_peak_idxs, merge_strf
from djimaging.utils.dj_utils import get_primary_key
from djimaging.utils.plot_utils import plot_srf, plot_trf, plot_signals_heatmap, set_long_title
from djimaging.utils.plot_utils import plot_srf, plot_trf, plot_signals_heatmap, set_long_title, prep_long_title


class SplitRFParamsTemplate(dj.Lookup):
Expand Down Expand Up @@ -200,7 +201,6 @@ def fetch1_pixel_size(self, key):

def make(self, key):
srf = (self.split_rf_table() & key).fetch1("srf")
stim_dict = (self.stimulus_table() & key).fetch1("stim_dict")

# Fit RF model
srf_fit, srf_params, qi = fit_rf_model(srf, kind='gauss', polarity=self._polarity)
Expand Down Expand Up @@ -229,10 +229,13 @@ def plot1(self, key=None):
key = get_primary_key(table=self, key=key)
srf = (self.split_rf_table() & key).fetch1("srf")
srf_fit, rf_qidx = (self & key).fetch1("srf_fit", 'rf_qidx')
srf_params = (self & key).fetch1("srf_params")

vabsmax = np.maximum(np.max(np.abs(srf)), np.max(np.abs(srf_fit)))

fig, axs = plt.subplots(1, 3, figsize=(10, 3))
fig, axs = plt.subplots(1, 4, figsize=(12, 3))

fig.suptitle(prep_long_title(key))

ax = axs[0]
plot_srf(srf, ax=ax, vabsmax=vabsmax)
Expand All @@ -246,6 +249,9 @@ def plot1(self, key=None):
plot_srf(srf - srf_fit, ax=ax, vabsmax=vabsmax)
ax.set_title(f'Difference: QI={rf_qidx:.2f}')

ax = axs[3]
plot_srf_gauss_fit(ax, srf=srf, srf_params=srf_params, vabsmax=vabsmax, plot_cb=True)

plt.tight_layout()
plt.show()

Expand Down Expand Up @@ -312,7 +318,6 @@ def fetch1_pixel_size(self, key):

def make(self, key):
srf = (self.split_rf_table() & key).fetch1("srf")
stim_dict = (self.stimulus_table() & key).fetch1("stim_dict")

srf_fit, srf_center_fit, srf_surround_fit, srf_params, eff_polarity, qi = fit_rf_model(
srf, kind='dog', polarity=self._polarity)
Expand Down Expand Up @@ -361,15 +366,17 @@ def make(self, key):
def plot1(self, key=None):
key = get_primary_key(table=self, key=key)
srf = (self.split_rf_table() & key).fetch1("srf")
srf_fit, srf_center_fit, srf_surround_fit, srf_eff_center, rf_qidx = (self & key).fetch1(
"srf_fit", 'srf_center_fit', 'srf_surround_fit', 'srf_eff_center', 'rf_qidx')
srf_fit, srf_center_fit, srf_surround_fit, srf_eff_center, rf_qidx, srf_ec_params = (self & key).fetch1(
"srf_fit", 'srf_center_fit', 'srf_surround_fit', 'srf_eff_center', 'rf_qidx', 'srf_eff_center_params')

vabsmax = np.maximum(np.max(np.abs(srf)), np.max(np.abs(srf_fit)))

fig, axs = plt.subplots(2, 3, figsize=(10, 6))

fig.suptitle(prep_long_title(key))

ax = axs[0, 0]
plot_srf(srf, ax=ax, vabsmax=vabsmax)
plot_srf_gauss_fit(ax, srf=srf, srf_params=srf_ec_params, vabsmax=vabsmax, plot_cb=True)
ax.set_title('sRF')

ax = axs[0, 1]
Expand All @@ -385,7 +392,7 @@ def plot1(self, key=None):
ax.set_title('sRF center fit')

ax = axs[1, 1]
plot_srf(srf_surround_fit, ax=ax, vabsmax=vabsmax)
plot_srf(srf_surround_fit, ax=ax, vabsmax=np.minimum(0.1 * vabsmax, np.abs(np.max(srf_surround_fit))))
ax.set_title('sRF surround fit')

ax = axs[1, 2]
Expand Down
25 changes: 25 additions & 0 deletions djimaging/utils/receptive_fields/plot_rf_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse


def plot_rf_frames(rf, rf_time, downsample=1):
Expand Down Expand Up @@ -58,3 +60,26 @@ def update(frame):

anim = FuncAnimation(fig, update, frames=np.arange(int(len(rf) * 1.5)), blit=True, interval=1000 / fps)
return HTML(anim.to_html5_video())


def plot_srf_gauss_fit(ax, srf=None, vabsmax=None, srf_params=None, n_std=2, color='k', ms=3, plot_cb=False, **kwargs):
if srf_params is not None:
ax.plot(srf_params['x_mean'], srf_params['y_mean'], zorder=100, marker='x', ms=ms, c=color, **kwargs)
ax.add_patch(Ellipse(
xy=(srf_params['x_mean'], srf_params['y_mean']),
width=n_std * 2 * srf_params['x_stddev'],
height=n_std * 2 * srf_params['y_stddev'],
angle=np.rad2deg(srf_params['theta']), color=color, fill=False, **kwargs))

if srf is not None:
if vabsmax is None:
vmin = np.min(srf)
vmax = np.max(srf)
cmap = 'gray'
else:
vmin = -vabsmax
vmax = vabsmax
cmap = 'bwr'
im = ax.imshow(srf, vmin=vmin, vmax=vmax, cmap=cmap, zorder=0, origin='lower')
if plot_cb:
plt.colorbar(im, ax=ax)

0 comments on commit 6b03adc

Please sign in to comment.