5
5
"""
6
6
7
7
import os
8
- import warnings
9
8
from textwrap import dedent
10
9
11
10
import numpy as np
26
25
expm1 ,
27
26
float64 ,
28
27
float_types ,
28
+ floor ,
29
29
identity ,
30
+ integer_types ,
30
31
isinf ,
31
32
log ,
32
33
log1p ,
@@ -853,15 +854,13 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
853
854
s_sign = - s_sign
854
855
855
856
# log will cast >int16 to float64
856
- log_s_inc = log_x - log (n )
857
- if log_s_inc .type .dtype != log_s .type .dtype :
858
- log_s_inc = log_s_inc .astype (log_s .type .dtype )
859
- log_s += log_s_inc
857
+ log_s += log_x - log (n )
858
+ if log_s .type .dtype != dtype :
859
+ log_s = log_s .astype (dtype )
860
860
861
- new_log_delta = log_s - 2 * log (n + k )
862
- if new_log_delta .type .dtype != log_delta .type .dtype :
863
- new_log_delta = new_log_delta .astype (log_delta .type .dtype )
864
- log_delta = new_log_delta
861
+ log_delta = log_s - 2 * log (n + k )
862
+ if log_delta .type .dtype != dtype :
863
+ log_delta = log_delta .astype (dtype )
865
864
866
865
n += 1
867
866
return (
@@ -1581,9 +1580,9 @@ def grad(self, inputs, grads):
1581
1580
a , b , c , z = inputs
1582
1581
(gz ,) = grads
1583
1582
return [
1584
- gz * hyp2f1_der (a , b , c , z , wrt = 0 ),
1585
- gz * hyp2f1_der (a , b , c , z , wrt = 1 ),
1586
- gz * hyp2f1_der (a , b , c , z , wrt = 2 ),
1583
+ gz * hyp2f1_grad (a , b , c , z , wrt = 0 ),
1584
+ gz * hyp2f1_grad (a , b , c , z , wrt = 1 ),
1585
+ gz * hyp2f1_grad (a , b , c , z , wrt = 2 ),
1587
1586
gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1588
1587
]
1589
1588
@@ -1594,134 +1593,165 @@ def c_code(self, *args, **kwargs):
1594
1593
hyp2f1 = Hyp2F1 (upgrade_to_float , name = "hyp2f1" )
1595
1594
1596
1595
1597
- class Hyp2F1Der ( ScalarOp ):
1598
- """
1599
- Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
1596
+ def _unsafe_sign ( x ):
1597
+ # Unlike scalar.sign we don't worry about x being 0 or nan
1598
+ return switch ( x > 0 , 1 , - 1 )
1600
1599
1601
- Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
1602
- """
1603
1600
1604
- nin = 5
1601
+ def hyp2f1_grad (a , b , c , z , wrt : int ):
1602
+ dtype = upcast (a .type .dtype , b .type .dtype , c .type .dtype , z .type .dtype , "float32" )
1605
1603
1606
- def impl (self , a , b , c , z , wrt ):
1607
- def check_2f1_converges (a , b , c , z ) -> bool :
1608
- num_terms = 0
1609
- is_polynomial = False
1604
+ def check_2f1_converges (a , b , c , z ):
1605
+ def is_nonpositive_integer (x ):
1606
+ if x .type .dtype not in integer_types :
1607
+ return eq (floor (x ), x ) & (x <= 0 )
1608
+ else :
1609
+ return x <= 0
1610
1610
1611
- def is_nonpositive_integer (x ):
1612
- return x <= 0 and x .is_integer ()
1611
+ a_is_polynomial = is_nonpositive_integer (a ) & (scalar_abs (a ) >= 0 )
1612
+ num_terms = switch (
1613
+ a_is_polynomial ,
1614
+ floor (scalar_abs (a )).astype ("int64" ),
1615
+ 0 ,
1616
+ )
1613
1617
1614
- if is_nonpositive_integer (a ) and abs ( a ) >= num_terms :
1615
- is_polynomial = True
1616
- num_terms = int ( np . floor ( abs ( a )))
1617
- if is_nonpositive_integer ( b ) and abs ( b ) >= num_terms :
1618
- is_polynomial = True
1619
- num_terms = int ( np . floor ( abs ( b )) )
1618
+ b_is_polynomial = is_nonpositive_integer (b ) & ( scalar_abs ( b ) >= num_terms )
1619
+ num_terms = switch (
1620
+ b_is_polynomial ,
1621
+ floor ( scalar_abs ( b )). astype ( "int64" ),
1622
+ num_terms ,
1623
+ )
1620
1624
1621
- is_undefined = is_nonpositive_integer (c ) and abs (c ) <= num_terms
1625
+ is_undefined = is_nonpositive_integer (c ) & (scalar_abs (c ) <= num_terms )
1626
+ is_polynomial = a_is_polynomial | b_is_polynomial
1622
1627
1623
- return not is_undefined and (
1624
- is_polynomial or np . abs ( z ) < 1 or ( np . abs ( z ) == 1 and c > (a + b ))
1625
- )
1628
+ return ( ~ is_undefined ) & (
1629
+ is_polynomial | ( scalar_abs ( z ) < 1 ) | ( eq ( scalar_abs ( z ), 1 ) & ( c > (a + b ) ))
1630
+ )
1626
1631
1627
- def compute_grad_2f1 (a , b , c , z , wrt ):
1628
- """
1629
- Notes
1630
- -----
1631
- The algorithm can be derived by looking at the ratio of two successive terms in the series
1632
- β_{k+1}/β_{k} = A(k)/B(k)
1633
- β_{k+1} = A(k)/B(k) * β_{k}
1634
- d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1635
-
1636
- In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1637
-
1638
- The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1639
- by dropping the respective term
1640
- d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1641
- d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1642
- d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1643
-
1644
- The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1645
- tracking their signs.
1646
- """
1632
+ def compute_grad_2f1 (a , b , c , z , wrt , skip_loop ):
1633
+ """
1634
+ Notes
1635
+ -----
1636
+ The algorithm can be derived by looking at the ratio of two successive terms in the series
1637
+ β_{k+1}/β_{k} = A(k)/B(k)
1638
+ β_{k+1} = A(k)/B(k) * β_{k}
1639
+ d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1640
+
1641
+ In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1642
+
1643
+ The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1644
+ by dropping the respective term
1645
+ d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1646
+ d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1647
+ d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1648
+
1649
+ The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1650
+ tracking their signs.
1651
+ """
1652
+
1653
+ wrt_a = wrt_b = False
1654
+ if wrt == 0 :
1655
+ wrt_a = True
1656
+ elif wrt == 1 :
1657
+ wrt_b = True
1658
+ elif wrt != 2 :
1659
+ raise ValueError (f"wrt must be 0, 1, or 2, got { wrt } " )
1660
+
1661
+ min_steps = np .array (
1662
+ 10 , dtype = "int32"
1663
+ ) # https://github.com/stan-dev/math/issues/2857
1664
+ max_steps = switch (
1665
+ skip_loop , np .array (0 , dtype = "int32" ), np .array (int (1e6 ), dtype = "int32" )
1666
+ )
1667
+ precision = np .array (1e-14 , dtype = config .floatX )
1647
1668
1648
- wrt_a = wrt_b = False
1649
- if wrt == 0 :
1650
- wrt_a = True
1651
- elif wrt == 1 :
1652
- wrt_b = True
1653
- elif wrt != 2 :
1654
- raise ValueError (f"wrt must be 0, 1, or 2, got { wrt } " )
1655
-
1656
- min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1657
- max_steps = int (1e6 )
1658
- precision = 1e-14
1659
-
1660
- res = 0
1661
-
1662
- if z == 0 :
1663
- return res
1664
-
1665
- log_g_old = - np .inf
1666
- log_t_old = 0.0
1667
- log_t_new = 0.0
1668
- sign_z = np .sign (z )
1669
- log_z = np .log (np .abs (z ))
1670
-
1671
- log_g_old_sign = 1
1672
- log_t_old_sign = 1
1673
- log_t_new_sign = 1
1674
- sign_zk = sign_z
1675
-
1676
- for k in range (max_steps ):
1677
- p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1678
- if p == 0 :
1679
- return res
1680
- log_t_new += np .log (np .abs (p )) + log_z
1681
- log_t_new_sign = np .sign (p ) * log_t_new_sign
1682
-
1683
- term = log_g_old_sign * log_t_old_sign * np .exp (log_g_old - log_t_old )
1684
- if wrt_a :
1685
- term += np .reciprocal (a + k )
1686
- elif wrt_b :
1687
- term += np .reciprocal (b + k )
1688
- else :
1689
- term -= np .reciprocal (c + k )
1690
-
1691
- log_g_old = log_t_new + np .log (np .abs (term ))
1692
- log_g_old_sign = np .sign (term ) * log_t_new_sign
1693
- g_current = log_g_old_sign * np .exp (log_g_old ) * sign_zk
1694
- res += g_current
1695
-
1696
- log_t_old = log_t_new
1697
- log_t_old_sign = log_t_new_sign
1698
- sign_zk *= sign_z
1699
-
1700
- if k >= min_steps and np .abs (g_current ) <= precision :
1701
- return res
1702
-
1703
- warnings .warn (
1704
- f"hyp2f1_der did not converge after { k } iterations" ,
1705
- RuntimeWarning ,
1706
- )
1707
- return np .nan
1669
+ grad = np .array (0 , dtype = dtype )
1670
+
1671
+ log_g = np .array (- np .inf , dtype = dtype )
1672
+ log_g_sign = np .array (1 , dtype = "int8" )
1673
+
1674
+ log_t = np .array (0.0 , dtype = dtype )
1675
+ log_t_sign = np .array (1 , dtype = "int8" )
1676
+
1677
+ log_z = log (scalar_abs (z ))
1678
+ sign_z = _unsafe_sign (z )
1679
+
1680
+ sign_zk = sign_z
1681
+ k = np .array (0 , dtype = "int32" )
1682
+
1683
+ def inner_loop (
1684
+ grad ,
1685
+ log_g ,
1686
+ log_g_sign ,
1687
+ log_t ,
1688
+ log_t_sign ,
1689
+ sign_zk ,
1690
+ k ,
1691
+ a ,
1692
+ b ,
1693
+ c ,
1694
+ log_z ,
1695
+ sign_z ,
1696
+ ):
1697
+ p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1698
+ if p .type .dtype != dtype :
1699
+ p = p .astype (dtype )
1700
+
1701
+ term = log_g_sign * log_t_sign * exp (log_g - log_t )
1702
+ if wrt_a :
1703
+ term += reciprocal (a + k )
1704
+ elif wrt_b :
1705
+ term += reciprocal (b + k )
1706
+ else :
1707
+ term -= reciprocal (c + k )
1708
+
1709
+ if term .type .dtype != dtype :
1710
+ term = term .astype (dtype )
1711
+
1712
+ log_t = log_t + log (scalar_abs (p )) + log_z
1713
+ log_t_sign = (_unsafe_sign (p ) * log_t_sign ).astype ("int8" )
1714
+ log_g = log_t + log (scalar_abs (term ))
1715
+ log_g_sign = (_unsafe_sign (term ) * log_t_sign ).astype ("int8" )
1716
+
1717
+ g_current = log_g_sign * exp (log_g ) * sign_zk
1708
1718
1709
- # TODO: We could implement the Euler transform to expand supported domain, as Stan does
1710
- if not check_2f1_converges ( a , b , c , z ):
1711
- warnings . warn (
1712
- f"Hyp2F1 does not meet convergence conditions with given arguments a= { a } , b= { b } , c= { c } , z= { z } " ,
1713
- RuntimeWarning ,
1719
+ # If p==0, don't update grad and get out of while loop next
1720
+ grad = switch (
1721
+ eq ( p , 0 ),
1722
+ grad ,
1723
+ grad + g_current ,
1714
1724
)
1715
- return np .nan
1716
1725
1717
- return compute_grad_2f1 (a , b , c , z , wrt = wrt )
1726
+ sign_zk *= sign_z
1727
+ k += 1
1718
1728
1719
- def __call__ (self , a , b , c , z , wrt , ** kwargs ):
1720
- # This allows wrt to be a keyword argument
1721
- return super ().__call__ (a , b , c , z , wrt , ** kwargs )
1729
+ return (
1730
+ (grad , log_g , log_g_sign , log_t , log_t_sign , sign_zk , k ),
1731
+ (eq (p , 0 ) | ((k > min_steps ) & (scalar_abs (g_current ) <= precision ))),
1732
+ )
1722
1733
1723
- def c_code (self , * args , ** kwargs ):
1724
- raise NotImplementedError ()
1734
+ init = [grad , log_g , log_g_sign , log_t , log_t_sign , sign_zk , k ]
1735
+ constant = [a , b , c , log_z , sign_z ]
1736
+ grad = _make_scalar_loop (
1737
+ max_steps , init , constant , inner_loop , name = "hyp2f1_grad"
1738
+ )
1725
1739
1740
+ return switch (
1741
+ eq (z , 0 ),
1742
+ 0 ,
1743
+ grad ,
1744
+ )
1726
1745
1727
- hyp2f1_der = Hyp2F1Der (upgrade_to_float , name = "hyp2f1_der" )
1746
+ # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
1747
+ z_is_zero = eq (z , 0 )
1748
+ converges = check_2f1_converges (a , b , c , z )
1749
+ return switch (
1750
+ z_is_zero ,
1751
+ 0 ,
1752
+ switch (
1753
+ converges ,
1754
+ compute_grad_2f1 (a , b , c , z , wrt , skip_loop = z_is_zero | (~ converges )),
1755
+ np .nan ,
1756
+ ),
1757
+ )
0 commit comments