Skip to content

Commit c247302

Browse files
committed
Use ScalarLoop for hyp2f1 gradient
1 parent 80f5a23 commit c247302

File tree

2 files changed

+328
-281
lines changed

2 files changed

+328
-281
lines changed

pytensor/scalar/math.py

Lines changed: 157 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
"""
66

77
import os
8-
import warnings
98
from textwrap import dedent
109

1110
import numpy as np
@@ -26,7 +25,9 @@
2625
expm1,
2726
float64,
2827
float_types,
28+
floor,
2929
identity,
30+
integer_types,
3031
isinf,
3132
log,
3233
log1p,
@@ -853,15 +854,13 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x):
853854
s_sign = -s_sign
854855

855856
# log will cast >int16 to float64
856-
log_s_inc = log_x - log(n)
857-
if log_s_inc.type.dtype != log_s.type.dtype:
858-
log_s_inc = log_s_inc.astype(log_s.type.dtype)
859-
log_s += log_s_inc
857+
log_s += log_x - log(n)
858+
if log_s.type.dtype != dtype:
859+
log_s = log_s.astype(dtype)
860860

861-
new_log_delta = log_s - 2 * log(n + k)
862-
if new_log_delta.type.dtype != log_delta.type.dtype:
863-
new_log_delta = new_log_delta.astype(log_delta.type.dtype)
864-
log_delta = new_log_delta
861+
log_delta = log_s - 2 * log(n + k)
862+
if log_delta.type.dtype != dtype:
863+
log_delta = log_delta.astype(dtype)
865864

866865
n += 1
867866
return (
@@ -1581,9 +1580,9 @@ def grad(self, inputs, grads):
15811580
a, b, c, z = inputs
15821581
(gz,) = grads
15831582
return [
1584-
gz * hyp2f1_der(a, b, c, z, wrt=0),
1585-
gz * hyp2f1_der(a, b, c, z, wrt=1),
1586-
gz * hyp2f1_der(a, b, c, z, wrt=2),
1583+
gz * hyp2f1_grad(a, b, c, z, wrt=0),
1584+
gz * hyp2f1_grad(a, b, c, z, wrt=1),
1585+
gz * hyp2f1_grad(a, b, c, z, wrt=2),
15871586
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
15881587
]
15891588

@@ -1594,134 +1593,165 @@ def c_code(self, *args, **kwargs):
15941593
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
15951594

15961595

1597-
class Hyp2F1Der(ScalarOp):
1598-
"""
1599-
Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
1596+
def _unsafe_sign(x):
1597+
# Unlike scalar.sign we don't worry about x being 0 or nan
1598+
return switch(x > 0, 1, -1)
16001599

1601-
Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
1602-
"""
16031600

1604-
nin = 5
1601+
def hyp2f1_grad(a, b, c, z, wrt: int):
1602+
dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
16051603

1606-
def impl(self, a, b, c, z, wrt):
1607-
def check_2f1_converges(a, b, c, z) -> bool:
1608-
num_terms = 0
1609-
is_polynomial = False
1604+
def check_2f1_converges(a, b, c, z):
1605+
def is_nonpositive_integer(x):
1606+
if x.type.dtype not in integer_types:
1607+
return eq(floor(x), x) & (x <= 0)
1608+
else:
1609+
return x <= 0
16101610

1611-
def is_nonpositive_integer(x):
1612-
return x <= 0 and x.is_integer()
1611+
a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
1612+
num_terms = switch(
1613+
a_is_polynomial,
1614+
floor(scalar_abs(a)).astype("int64"),
1615+
0,
1616+
)
16131617

1614-
if is_nonpositive_integer(a) and abs(a) >= num_terms:
1615-
is_polynomial = True
1616-
num_terms = int(np.floor(abs(a)))
1617-
if is_nonpositive_integer(b) and abs(b) >= num_terms:
1618-
is_polynomial = True
1619-
num_terms = int(np.floor(abs(b)))
1618+
b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
1619+
num_terms = switch(
1620+
b_is_polynomial,
1621+
floor(scalar_abs(b)).astype("int64"),
1622+
num_terms,
1623+
)
16201624

1621-
is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
1625+
is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
1626+
is_polynomial = a_is_polynomial | b_is_polynomial
16221627

1623-
return not is_undefined and (
1624-
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
1625-
)
1628+
return (~is_undefined) & (
1629+
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
1630+
)
16261631

1627-
def compute_grad_2f1(a, b, c, z, wrt):
1628-
"""
1629-
Notes
1630-
-----
1631-
The algorithm can be derived by looking at the ratio of two successive terms in the series
1632-
β_{k+1}/β_{k} = A(k)/B(k)
1633-
β_{k+1} = A(k)/B(k) * β_{k}
1634-
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1635-
1636-
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1637-
1638-
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1639-
by dropping the respective term
1640-
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1641-
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1642-
d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1643-
1644-
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1645-
tracking their signs.
1646-
"""
1632+
def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
1633+
"""
1634+
Notes
1635+
-----
1636+
The algorithm can be derived by looking at the ratio of two successive terms in the series
1637+
β_{k+1}/β_{k} = A(k)/B(k)
1638+
β_{k+1} = A(k)/B(k) * β_{k}
1639+
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1640+
1641+
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1642+
1643+
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1644+
by dropping the respective term
1645+
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1646+
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1647+
d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1648+
1649+
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1650+
tracking their signs.
1651+
"""
1652+
1653+
wrt_a = wrt_b = False
1654+
if wrt == 0:
1655+
wrt_a = True
1656+
elif wrt == 1:
1657+
wrt_b = True
1658+
elif wrt != 2:
1659+
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
1660+
1661+
min_steps = np.array(
1662+
10, dtype="int32"
1663+
) # https://github.com/stan-dev/math/issues/2857
1664+
max_steps = switch(
1665+
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
1666+
)
1667+
precision = np.array(1e-14, dtype=config.floatX)
16471668

1648-
wrt_a = wrt_b = False
1649-
if wrt == 0:
1650-
wrt_a = True
1651-
elif wrt == 1:
1652-
wrt_b = True
1653-
elif wrt != 2:
1654-
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
1655-
1656-
min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1657-
max_steps = int(1e6)
1658-
precision = 1e-14
1659-
1660-
res = 0
1661-
1662-
if z == 0:
1663-
return res
1664-
1665-
log_g_old = -np.inf
1666-
log_t_old = 0.0
1667-
log_t_new = 0.0
1668-
sign_z = np.sign(z)
1669-
log_z = np.log(np.abs(z))
1670-
1671-
log_g_old_sign = 1
1672-
log_t_old_sign = 1
1673-
log_t_new_sign = 1
1674-
sign_zk = sign_z
1675-
1676-
for k in range(max_steps):
1677-
p = (a + k) * (b + k) / ((c + k) * (k + 1))
1678-
if p == 0:
1679-
return res
1680-
log_t_new += np.log(np.abs(p)) + log_z
1681-
log_t_new_sign = np.sign(p) * log_t_new_sign
1682-
1683-
term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
1684-
if wrt_a:
1685-
term += np.reciprocal(a + k)
1686-
elif wrt_b:
1687-
term += np.reciprocal(b + k)
1688-
else:
1689-
term -= np.reciprocal(c + k)
1690-
1691-
log_g_old = log_t_new + np.log(np.abs(term))
1692-
log_g_old_sign = np.sign(term) * log_t_new_sign
1693-
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
1694-
res += g_current
1695-
1696-
log_t_old = log_t_new
1697-
log_t_old_sign = log_t_new_sign
1698-
sign_zk *= sign_z
1699-
1700-
if k >= min_steps and np.abs(g_current) <= precision:
1701-
return res
1702-
1703-
warnings.warn(
1704-
f"hyp2f1_der did not converge after {k} iterations",
1705-
RuntimeWarning,
1706-
)
1707-
return np.nan
1669+
grad = np.array(0, dtype=dtype)
1670+
1671+
log_g = np.array(-np.inf, dtype=dtype)
1672+
log_g_sign = np.array(1, dtype="int8")
1673+
1674+
log_t = np.array(0.0, dtype=dtype)
1675+
log_t_sign = np.array(1, dtype="int8")
1676+
1677+
log_z = log(scalar_abs(z))
1678+
sign_z = _unsafe_sign(z)
1679+
1680+
sign_zk = sign_z
1681+
k = np.array(0, dtype="int32")
1682+
1683+
def inner_loop(
1684+
grad,
1685+
log_g,
1686+
log_g_sign,
1687+
log_t,
1688+
log_t_sign,
1689+
sign_zk,
1690+
k,
1691+
a,
1692+
b,
1693+
c,
1694+
log_z,
1695+
sign_z,
1696+
):
1697+
p = (a + k) * (b + k) / ((c + k) * (k + 1))
1698+
if p.type.dtype != dtype:
1699+
p = p.astype(dtype)
1700+
1701+
term = log_g_sign * log_t_sign * exp(log_g - log_t)
1702+
if wrt_a:
1703+
term += reciprocal(a + k)
1704+
elif wrt_b:
1705+
term += reciprocal(b + k)
1706+
else:
1707+
term -= reciprocal(c + k)
1708+
1709+
if term.type.dtype != dtype:
1710+
term = term.astype(dtype)
1711+
1712+
log_t = log_t + log(scalar_abs(p)) + log_z
1713+
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
1714+
log_g = log_t + log(scalar_abs(term))
1715+
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
1716+
1717+
g_current = log_g_sign * exp(log_g) * sign_zk
17081718

1709-
# TODO: We could implement the Euler transform to expand supported domain, as Stan does
1710-
if not check_2f1_converges(a, b, c, z):
1711-
warnings.warn(
1712-
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
1713-
RuntimeWarning,
1719+
# If p==0, don't update grad and get out of while loop next
1720+
grad = switch(
1721+
eq(p, 0),
1722+
grad,
1723+
grad + g_current,
17141724
)
1715-
return np.nan
17161725

1717-
return compute_grad_2f1(a, b, c, z, wrt=wrt)
1726+
sign_zk *= sign_z
1727+
k += 1
17181728

1719-
def __call__(self, a, b, c, z, wrt, **kwargs):
1720-
# This allows wrt to be a keyword argument
1721-
return super().__call__(a, b, c, z, wrt, **kwargs)
1729+
return (
1730+
(grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k),
1731+
(eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))),
1732+
)
17221733

1723-
def c_code(self, *args, **kwargs):
1724-
raise NotImplementedError()
1734+
init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k]
1735+
constant = [a, b, c, log_z, sign_z]
1736+
grad = _make_scalar_loop(
1737+
max_steps, init, constant, inner_loop, name="hyp2f1_grad"
1738+
)
17251739

1740+
return switch(
1741+
eq(z, 0),
1742+
0,
1743+
grad,
1744+
)
17261745

1727-
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
1746+
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy
1747+
z_is_zero = eq(z, 0)
1748+
converges = check_2f1_converges(a, b, c, z)
1749+
return switch(
1750+
z_is_zero,
1751+
0,
1752+
switch(
1753+
converges,
1754+
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
1755+
np.nan,
1756+
),
1757+
)

0 commit comments

Comments
 (0)