@@ -1719,3 +1719,183 @@ def bernoulli(n: int) -> Array:
1719
1719
k = jnp .arange (2 , 50 , dtype = bn .dtype ) # Choose 50 because 2 ** -50 < 1E-15
1720
1720
q2 = jnp .sum (k [:, None ] ** - m [None , :], axis = 0 )
1721
1721
return bn .at [4 ::2 ].set (q1 * (1 + q2 ))
1722
+
1723
+
1724
+ @custom_derivatives .custom_jvp
1725
+ @_wraps (osp_special .poch , module = 'scipy.special' , lax_description = """\
1726
+ The JAX version only accepts positive and real inputs.""" )
1727
+ def poch (z : ArrayLike , m : ArrayLike ) -> Array :
1728
+ # Factorial definition when m is close to an integer, otherwise gamma definition.
1729
+ z , m = promote_args_inexact ("poch" , z , m )
1730
+
1731
+ return jnp .where (m == 0. , jnp .array (1 , dtype = z .dtype ), gamma (z + m ) / gamma (z ))
1732
+
1733
+
1734
+ def _poch_z_derivative (z , m ):
1735
+ """
1736
+ Defined in :
1737
+ https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/01/
1738
+ """
1739
+
1740
+ return (digamma (z + m ) - digamma (z )) * poch (z , m )
1741
+
1742
+
1743
+ def _poch_m_derivative (z , m ):
1744
+ """
1745
+ Defined in :
1746
+ https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/02/
1747
+ """
1748
+
1749
+ return digamma (z + m ) * poch (z , m )
1750
+
1751
+
1752
+ poch .defjvps (
1753
+ lambda z_dot , primal_out , z , m : _poch_z_derivative (z , m ) * z_dot ,
1754
+ lambda m_dot , primal_out , z , m : _poch_m_derivative (z , m ) * m_dot ,
1755
+ )
1756
+
1757
+
1758
+ def _hyp1f1_serie (a , b , x ):
1759
+ """
1760
+ Compute the 1F1 hypergeometric function using the taylor expansion
1761
+ See Eq. 3.2 and associated method (a) from PEARSON, OLVER & PORTER 2014
1762
+ https://doi.org/10.48550/arXiv.1407.7786
1763
+ """
1764
+
1765
+ def body (state ):
1766
+ serie , k , term = state
1767
+ serie += term
1768
+ term *= (a + k ) / (b + k ) * x / (k + 1 )
1769
+ k += 1
1770
+
1771
+ return serie , k , term
1772
+
1773
+ def cond (state ):
1774
+ serie , k , term = state
1775
+
1776
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-8 )
1777
+
1778
+ init = 1 , 1 , a / b * x
1779
+
1780
+ return lax .while_loop (cond , body , init )[0 ]
1781
+
1782
+
1783
+ def _hyp1f1_asymptotic (a , b , x ):
1784
+ """
1785
+ Compute the 1F1 hypergeometric function using asymptotic expansion
1786
+ See Eq. 3.8 and simplification for real inputs from PEARSON, OLVER & PORTER 2014
1787
+ https://doi.org/10.48550/arXiv.1407.7786
1788
+ """
1789
+
1790
+ def body (state ):
1791
+ serie , k , term = state
1792
+ serie += term
1793
+ term *= (b - a + k ) * (1 - a + k ) / (k + 1 ) / x
1794
+ k += 1
1795
+
1796
+ return serie , k , term
1797
+
1798
+ def cond (state ):
1799
+ serie , k , term = state
1800
+
1801
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-8 )
1802
+
1803
+ init = 1 , 1 , (b - a ) * (1 - a ) / x
1804
+ serie = lax .while_loop (cond , body , init )[0 ]
1805
+
1806
+ return gamma (b ) / gamma (a ) * lax .exp (x ) * x ** (a - b ) * serie
1807
+
1808
+
1809
+ @jit
1810
+ @jnp .vectorize
1811
+ def _hyp1f1_a_derivative (a , b , x ):
1812
+ """
1813
+ Define it as a serie using :
1814
+ https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
1815
+ """
1816
+
1817
+ def body (state ):
1818
+ serie , k , term = state
1819
+ serie += term * (digamma (a + k ) - digamma (a ))
1820
+ term *= (a + k ) / (b + k ) * x / (k + 1 )
1821
+ k += 1
1822
+
1823
+ return serie , k , term
1824
+
1825
+ def cond (state ):
1826
+ serie , k , term = state
1827
+
1828
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-15 )
1829
+
1830
+ init = 0 , 1 , a / b * x
1831
+
1832
+ return lax .while_loop (cond , body , init )[0 ]
1833
+
1834
+
1835
+ @jit
1836
+ @jnp .vectorize
1837
+ def _hyp1f1_b_derivative (a , b , x ):
1838
+ """
1839
+ Define it as a serie using :
1840
+ https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
1841
+ """
1842
+
1843
+ def body (state ):
1844
+ serie , k , term = state
1845
+ serie += term * (digamma (b ) - digamma (b + k ))
1846
+ term *= (a + k ) / (b + k ) * x / (k + 1 )
1847
+ k += 1
1848
+
1849
+ return serie , k , term
1850
+
1851
+ def cond (state ):
1852
+ serie , k , term = state
1853
+
1854
+ return (k < 250 ) & (lax .abs (term ) / lax .abs (serie ) > 1e-15 )
1855
+
1856
+ init = 0 , 1 , a / b * x
1857
+
1858
+ return lax .while_loop (cond , body , init )[0 ]
1859
+
1860
+
1861
+ @jit
1862
+ def _hyp1f1_x_derivative (a , b , x ):
1863
+ """
1864
+ Define it as a serie using :
1865
+ https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/04/
1866
+ """
1867
+
1868
+ return a / b * hyp1f1 (a + 1 , b + 1 , x )
1869
+
1870
+
1871
+ @custom_derivatives .custom_jvp
1872
+ @jit
1873
+ @jnp .vectorize
1874
+ @_wraps (osp_special .hyp1f1 , module = 'scipy.special' , lax_description = """\
1875
+ The JAX version only accepts positive and real inputs. Values of a, b and x
1876
+ leading to high values of 1F1 might be erroneous, considering enabling double
1877
+ precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""" )
1878
+ def hyp1f1 (a , b , x ):
1879
+ """
1880
+ Implementation of the 1F1 hypergeometric function for real valued inputs
1881
+ Backed by https://doi.org/10.48550/arXiv.1407.7786
1882
+ There is room for improvement in the implementation using recursion to
1883
+ evaluate lower values of hyp1f1 when a or b or both are > 60-80
1884
+ """
1885
+ a , b , x = promote_args_inexact ('hyp1f1' , a , b , x )
1886
+
1887
+ result = lax .cond (lax .abs (x ) < 100 , _hyp1f1_serie , _hyp1f1_asymptotic , a , b , x )
1888
+ index = (a == 0 ) * 1 + ((a == b ) & (a != 0 )) * 2 + ((b == 0 ) & (a != 0 )) * 3
1889
+
1890
+ return lax .select_n (index ,
1891
+ result ,
1892
+ jnp .array (1 , dtype = x .dtype ),
1893
+ jnp .exp (x ),
1894
+ jnp .array (jnp .inf , dtype = x .dtype ))
1895
+
1896
+
1897
+ hyp1f1 .defjvps (
1898
+ lambda a_dot , primal_out , a , b , x : _hyp1f1_a_derivative (a , b , x ) * a_dot ,
1899
+ lambda b_dot , primal_out , a , b , x : _hyp1f1_b_derivative (a , b , x ) * b_dot ,
1900
+ lambda x_dot , primal_out , a , b , x : _hyp1f1_x_derivative (a , b , x ) * x_dot
1901
+ )
0 commit comments