diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index b1f1f84570..ee67a89e68 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -201,9 +201,7 @@ def logp(self, value): return tt.zeros_like(value) def _repr_latex_(self, name=None, dist=None): - if dist is None: - dist = self - return r'${} \sim \text{Flat}()$' + return r'${} \sim \text{Flat}()$'.format(name) class HalfFlat(PositiveContinuous): @@ -220,9 +218,7 @@ def logp(self, value): return bound(tt.zeros_like(value), value > 0) def _repr_latex_(self, name=None, dist=None): - if dist is None: - dist = self - return r'${} \sim \text{{HalfFlat}()$' + return r'${} \sim \text{{HalfFlat}()$'.format(name) class Normal(Continuous): diff --git a/pymc3/model.py b/pymc3/model.py index 1c626d01cb..01ac38ddc9 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -17,7 +17,7 @@ from .theanof import gradient, hessian, inputvars, generator from .vartypes import typefilter, discrete_types, continuous_types, isgenerator from .blocking import DictToArrayBijection, ArrayOrdering -from .util import get_transformed_name +from .util import get_transformed_name, escape_latex __all__ = [ 'Model', 'Factor', 'compilef', 'fn', 'fastfn', 'modelcontext', @@ -1081,7 +1081,7 @@ def _repr_latex_(self, name=None, dist=None): name = self.name if dist is None: dist = self.distribution - return self.distribution._repr_latex_(name=name, dist=dist) + return self.distribution._repr_latex_(name=escape_latex(name), dist=dist) __latex__ = _repr_latex_ @@ -1186,7 +1186,7 @@ def _repr_latex_(self, name=None, dist=None): name = self.name if dist is None: dist = self.distribution - return self.distribution._repr_latex_(name=name, dist=dist) + return self.distribution._repr_latex_(name=escape_latex(name), dist=dist) __latex__ = _repr_latex_ @@ -1335,7 +1335,7 @@ def _repr_latex_(self, name=None, dist=None): name = self.name if dist is None: dist = self.distribution - return self.distribution._repr_latex_(name=name, dist=dist) + return self.distribution._repr_latex_(name=escape_latex(name), dist=dist) __latex__ = _repr_latex_ diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 8ca047384f..312e6a0134 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -957,11 +957,11 @@ def setup_class(self): Y_obs = Normal('Y_obs', mu=mu, sd=sigma, observed=Y) self.distributions = [alpha, sigma, mu, b, Y_obs] self.expected = ( - '$alpha \\sim \\text{Normal}(\\mathit{mu}=0, \\mathit{sd}=10.0)$', - '$sigma \\sim \\text{HalfNormal}(\\mathit{sd}=1.0)$', - '$mu \\sim \\text{Deterministic}(alpha, \\text{Constant}, beta)$', - '$beta \\sim \\text{Normal}(\\mathit{mu}=0, \\mathit{sd}=10.0)$', - '$Y_obs \\sim \\text{Normal}(\\mathit{mu}=mu, \\mathit{sd}=f(sigma))$' + r'$alpha \sim \text{Normal}(\mathit{mu}=0, \mathit{sd}=10.0)$', + r'$sigma \sim \text{HalfNormal}(\mathit{sd}=1.0)$', + r'$mu \sim \text{Deterministic}(alpha, \text{Constant}, beta)$', + r'$beta \sim \text{Normal}(\mathit{mu}=0, \mathit{sd}=10.0)$', + r'$Y\_obs \sim \text{Normal}(\mathit{mu}=mu, \mathit{sd}=f(sigma))$' ) def test__repr_latex_(self): diff --git a/pymc3/util.py b/pymc3/util.py index bfcbfe3a34..db24777dc8 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -1,5 +1,33 @@ +import re + from numpy import asscalar +LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE) + + +def escape_latex(strng): + """Consistently escape LaTeX special characters for _repr_latex_ in IPython + + Implementation taken from the IPython magic `format_latex` + + Example + ------- + escape_latex('disease_rate') # 'disease\_rate' + + Parameters + ---------- + strng : str + string to escape LaTeX characters + + Returns + ------- + str + A string with LaTeX escaped + """ + if strng is None: + return u'None' + return LATEX_ESCAPE_RE.sub(r'\\\1', strng) + def get_transformed_name(name, transform): """ @@ -14,7 +42,7 @@ def get_transformed_name(name, transform): Returns ------- - str + str A string to use for the transformed variable """ return "{}_{}__".format(name, transform.name) @@ -88,6 +116,7 @@ def get_variable_name(variable): try: names = [get_variable_name(item) for item in variable.get_parents()[0].inputs] + # do not escape_latex these, since it is not idempotent return 'f(%s)' % ','.join([n for n in names if isinstance(n, str)]) except IndexError: pass @@ -95,7 +124,7 @@ def get_variable_name(variable): if not value.shape: return asscalar(value) return 'array' - return name + return escape_latex(name) def update_start_vals(a, b, model):