Skip to content

Commit 840b5c5

Browse files
author
jax authors
committed
Merge pull request #18499 from renecotyfanboy:hyp1f1_poch
PiperOrigin-RevId: 582765493
2 parents 946819f + 47ca51f commit 840b5c5

File tree

4 files changed

+186
-1
lines changed

4 files changed

+186
-1
lines changed

Diff for: docs/jax.scipy.rst

+2
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ jax.scipy.special
147147
gammainc
148148
gammaincc
149149
gammaln
150+
hyp1f1
150151
i0
151152
i0e
152153
i1
@@ -159,6 +160,7 @@ jax.scipy.special
159160
multigammaln
160161
ndtr
161162
ndtri
163+
poch
162164
polygamma
163165
spence
164166
sph_harm

Diff for: jax/_src/scipy/special.py

+180
Original file line numberDiff line numberDiff line change
@@ -1719,3 +1719,183 @@ def bernoulli(n: int) -> Array:
17191719
k = jnp.arange(2, 50, dtype=bn.dtype) # Choose 50 because 2 ** -50 < 1E-15
17201720
q2 = jnp.sum(k[:, None] ** -m[None, :], axis=0)
17211721
return bn.at[4::2].set(q1 * (1 + q2))
1722+
1723+
1724+
@custom_derivatives.custom_jvp
1725+
@_wraps(osp_special.poch, module='scipy.special', lax_description="""\
1726+
The JAX version only accepts positive and real inputs.""")
1727+
def poch(z: ArrayLike, m: ArrayLike) -> Array:
1728+
# Factorial definition when m is close to an integer, otherwise gamma definition.
1729+
z, m = promote_args_inexact("poch", z, m)
1730+
1731+
return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z))
1732+
1733+
1734+
def _poch_z_derivative(z, m):
1735+
"""
1736+
Defined in :
1737+
https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/01/
1738+
"""
1739+
1740+
return (digamma(z + m) - digamma(z)) * poch(z, m)
1741+
1742+
1743+
def _poch_m_derivative(z, m):
1744+
"""
1745+
Defined in :
1746+
https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/02/
1747+
"""
1748+
1749+
return digamma(z + m) * poch(z, m)
1750+
1751+
1752+
poch.defjvps(
1753+
lambda z_dot, primal_out, z, m: _poch_z_derivative(z, m) * z_dot,
1754+
lambda m_dot, primal_out, z, m: _poch_m_derivative(z, m) * m_dot,
1755+
)
1756+
1757+
1758+
def _hyp1f1_serie(a, b, x):
1759+
"""
1760+
Compute the 1F1 hypergeometric function using the taylor expansion
1761+
See Eq. 3.2 and associated method (a) from PEARSON, OLVER & PORTER 2014
1762+
https://doi.org/10.48550/arXiv.1407.7786
1763+
"""
1764+
1765+
def body(state):
1766+
serie, k, term = state
1767+
serie += term
1768+
term *= (a + k) / (b + k) * x / (k + 1)
1769+
k += 1
1770+
1771+
return serie, k, term
1772+
1773+
def cond(state):
1774+
serie, k, term = state
1775+
1776+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
1777+
1778+
init = 1, 1, a / b * x
1779+
1780+
return lax.while_loop(cond, body, init)[0]
1781+
1782+
1783+
def _hyp1f1_asymptotic(a, b, x):
1784+
"""
1785+
Compute the 1F1 hypergeometric function using asymptotic expansion
1786+
See Eq. 3.8 and simplification for real inputs from PEARSON, OLVER & PORTER 2014
1787+
https://doi.org/10.48550/arXiv.1407.7786
1788+
"""
1789+
1790+
def body(state):
1791+
serie, k, term = state
1792+
serie += term
1793+
term *= (b - a + k) * (1 - a + k) / (k + 1) / x
1794+
k += 1
1795+
1796+
return serie, k, term
1797+
1798+
def cond(state):
1799+
serie, k, term = state
1800+
1801+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
1802+
1803+
init = 1, 1, (b - a) * (1 - a) / x
1804+
serie = lax.while_loop(cond, body, init)[0]
1805+
1806+
return gamma(b) / gamma(a) * lax.exp(x) * x ** (a - b) * serie
1807+
1808+
1809+
@jit
1810+
@jnp.vectorize
1811+
def _hyp1f1_a_derivative(a, b, x):
1812+
"""
1813+
Define it as a serie using :
1814+
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
1815+
"""
1816+
1817+
def body(state):
1818+
serie, k, term = state
1819+
serie += term * (digamma(a + k) - digamma(a))
1820+
term *= (a + k) / (b + k) * x / (k + 1)
1821+
k += 1
1822+
1823+
return serie, k, term
1824+
1825+
def cond(state):
1826+
serie, k, term = state
1827+
1828+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
1829+
1830+
init = 0, 1, a / b * x
1831+
1832+
return lax.while_loop(cond, body, init)[0]
1833+
1834+
1835+
@jit
1836+
@jnp.vectorize
1837+
def _hyp1f1_b_derivative(a, b, x):
1838+
"""
1839+
Define it as a serie using :
1840+
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
1841+
"""
1842+
1843+
def body(state):
1844+
serie, k, term = state
1845+
serie += term * (digamma(b) - digamma(b + k))
1846+
term *= (a + k) / (b + k) * x / (k + 1)
1847+
k += 1
1848+
1849+
return serie, k, term
1850+
1851+
def cond(state):
1852+
serie, k, term = state
1853+
1854+
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
1855+
1856+
init = 0, 1, a / b * x
1857+
1858+
return lax.while_loop(cond, body, init)[0]
1859+
1860+
1861+
@jit
1862+
def _hyp1f1_x_derivative(a, b, x):
1863+
"""
1864+
Define it as a serie using :
1865+
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/04/
1866+
"""
1867+
1868+
return a / b * hyp1f1(a + 1, b + 1, x)
1869+
1870+
1871+
@custom_derivatives.custom_jvp
1872+
@jit
1873+
@jnp.vectorize
1874+
@_wraps(osp_special.hyp1f1, module='scipy.special', lax_description="""\
1875+
The JAX version only accepts positive and real inputs. Values of a, b and x
1876+
leading to high values of 1F1 might be erroneous, considering enabling double
1877+
precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""")
1878+
def hyp1f1(a, b, x):
1879+
"""
1880+
Implementation of the 1F1 hypergeometric function for real valued inputs
1881+
Backed by https://doi.org/10.48550/arXiv.1407.7786
1882+
There is room for improvement in the implementation using recursion to
1883+
evaluate lower values of hyp1f1 when a or b or both are > 60-80
1884+
"""
1885+
a, b, x = promote_args_inexact('hyp1f1', a, b, x)
1886+
1887+
result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x)
1888+
index = (a == 0) * 1 + ((a == b) & (a != 0)) * 2 + ((b == 0) & (a != 0)) * 3
1889+
1890+
return lax.select_n(index,
1891+
result,
1892+
jnp.array(1, dtype=x.dtype),
1893+
jnp.exp(x),
1894+
jnp.array(jnp.inf, dtype=x.dtype))
1895+
1896+
1897+
hyp1f1.defjvps(
1898+
lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot,
1899+
lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
1900+
lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
1901+
)

Diff for: jax/scipy/special.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,6 @@
5454
zeta as zeta,
5555
kl_div as kl_div,
5656
rel_entr as rel_entr,
57+
poch as poch,
58+
hyp1f1 as hyp1f1,
5759
)

Diff for: tests/lax_scipy_special_functions_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
141141
op_record(
142142
"rel_entr", 2, float_dtypes, jtu.rand_positive, True,
143143
),
144-
144+
op_record("poch", 2, float_dtypes, jtu.rand_positive, True),
145+
op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True)
145146
]
146147

147148

0 commit comments

Comments
 (0)