Skip to content

Commit

Permalink
Merge pull request #291 from nextstrain/opt_dependencies
Browse files Browse the repository at this point in the history
setup.py: removed matplotlib, moved ipdb to dev. augur/frequency_esti…
  • Loading branch information
rneher authored Jul 2, 2019
2 parents d1cf443 + 1163397 commit 108053e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 24 deletions.
14 changes: 12 additions & 2 deletions augur/frequency_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,12 @@ def calc_confidence(self):


def test_simple_estimator():
import matplotlib.pyplot as plt
try:
import matplotlib.pyplot as plt
except ImportError:
plot=False
print("Plotting requires a working matplotlib installation.")

tps = np.sort(100 * np.random.uniform(size=500))
freq_traj = [0.1]
stiffness=100
Expand Down Expand Up @@ -755,7 +760,12 @@ def test_simple_estimator():
return fe

def test_nested_estimator():
import matplotlib.pyplot as plt
try:
import matplotlib.pyplot as plt
except ImportError:
plot=False
print("Plotting requires a working matplotlib installation.")

tps = np.sort(100 * np.random.uniform(size=2000))
freq_traj = [0.1]
stiffness=1000
Expand Down
43 changes: 24 additions & 19 deletions augur/titer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,26 +502,31 @@ def validate(self, plot=False, cutoff=0.0, validation_set = None, fname=None):
model_performance['values'] = validation.values()

self.validation = model_performance

if plot:
import matplotlib.pyplot as plt
import seaborn as sns
fs=16
sns.set_style('darkgrid')
plt.figure()
ax = plt.subplot(111)
plt.plot([-1,6], [-1,6], 'k')
plt.scatter(actual, predicted)
plt.ylabel(r"predicted $\log_2$ distance", fontsize = fs)
plt.xlabel(r"measured $\log_2$ distance" , fontsize = fs)
ax.tick_params(axis='both', labelsize=fs)
plt.text(-2.5,6,'regularization:\nprediction error:\nR^2:', fontsize = fs-2)
plt.text(1.2,6, str(self.lam_drop)+'/'+str(self.lam_pot)+'/'+str(self.lam_avi)+' (HI/pot/avi)'
+'\n'+str(round(model_performance['abs_error'], 2))+'/'+str(round(model_performance['rms_error'], 2))+' (abs/rms)'
+ '\n' + str(model_performance['r_squared']), fontsize = fs-2)
plt.tight_layout()

if fname:
plt.savefig(fname)
try:
import matplotlib.pyplot as plt
import seaborn as sns
except ImportError:
print("Plotting requires a working matplotlib and seaborn installation.")
else:
fs=16
sns.set_style('darkgrid')
plt.figure()
ax = plt.subplot(111)
plt.plot([-1,6], [-1,6], 'k')
plt.scatter(actual, predicted)
plt.ylabel(r"predicted $\log_2$ distance", fontsize = fs)
plt.xlabel(r"measured $\log_2$ distance" , fontsize = fs)
ax.tick_params(axis='both', labelsize=fs)
plt.text(-2.5,6,'regularization:\nprediction error:\nR^2:', fontsize = fs-2)
plt.text(1.2,6, str(self.lam_drop)+'/'+str(self.lam_pot)+'/'+str(self.lam_avi)+' (HI/pot/avi)'
+'\n'+str(round(model_performance['abs_error'], 2))+'/'+str(round(model_performance['rms_error'], 2))+' (abs/rms)'
+ '\n' + str(model_performance['r_squared']), fontsize = fs-2)
plt.tight_layout()

if fname:
plt.savefig(fname)

return model_performance

Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,9 @@
"biopython >=1.73, ==1.*",
"boto >=2.38, ==2.*",
"cvxopt >=1.1.9, ==1.1.*",
"ipdb >=0.10.1",
"jsonschema ==3.0.0a3",
"matplotlib >=2.0, ==2.*",
"pandas >=0.23.4",
"phylo-treetime >=0.5.6, ==0.5.*",
"seaborn >=0.9.0",
"snakemake >=5.1.5, ==5.*"
],
extras_require={
Expand All @@ -54,6 +51,7 @@
"sphinx-argparse >=0.2.5",
"sphinx-rtd-theme >=0.4.3",
"wheel >=0.32.3, ==0.32.*",
"ipdb >=0.10.1"
]
},
classifiers=[
Expand Down

0 comments on commit 108053e

Please sign in to comment.