Commit cc4fab6

Commit message: .progress
1 parent 3b173b3 commit cc4fab6

File tree

7 files changed (+126, -38 lines)


pytensor/scan/rewriting.py

Lines changed: 8 additions & 8 deletions
@@ -2561,14 +2561,14 @@ def scan_push_out_dot1(fgraph, node):
     position=1,
 )
 
-scan_seqopt1.register(
-    "scan_decompose_compound_ops,
-    in2out(scan_decompose_compound_ops),
-    "fast_run",
-    "scan",
-    "scan_pushout",
-    position=2,
-)
+# scan_seqopt1.register(
+#     "scan_decompose_compound_ops",
+#     in2out(scan_decompose_compound_ops),
+#     "fast_run",
+#     "scan",
+#     "scan_pushout",
+#     position=2,
+# )
 
 scan_seqopt1.register(
     "scan_push_out_non_seq",

pytensor/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -114,6 +114,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 
 
 # isort: off
+import pytensor.tensor._linalg
 from pytensor.tensor import linalg
 from pytensor.tensor import special
 from pytensor.tensor import signal
@@ -143,6 +144,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
     specify_shape,
 )
 
+
 # We import as `_shared` instead of `shared` to avoid confusion between
 # `pytensor.shared` and `tensor._shared`.
 from pytensor.tensor.sort import argsort, sort
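The added import exists purely for its side effect: loading pytensor.tensor._linalg triggers the decorators in its rewrites module, which register the new rewrites at import time. A minimal sketch of this register-on-import pattern, with hypothetical module and function names:

# my_rewrites.py -- hypothetical module showing register-on-import.
from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import register_specialize


@register_specialize        # runs at import time and files the rewrite away
@node_rewriter([Blockwise])
def my_rewrite(fgraph, node):
    return None             # placeholder body


# A package __init__.py then only needs `import my_rewrites`; nothing has to
# be re-exported for the registration to take effect.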

pytensor/tensor/_linalg/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# Register rewrites
+import pytensor.tensor._linalg.solve
pytensor/tensor/_linalg/solve/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# Register rewrites in the database
+import pytensor.tensor._linalg.solve.rewrites
pytensor/tensor/_linalg/solve/rewrites.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+from pytensor.graph import node_rewriter
+from pytensor.tensor._linalg.solve.tridiagonal import split_solve_tridiagonal, decompose_of_solve_tridiagonal
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.linalg import is_matrix_transpose
+from pytensor.tensor.slinalg import Solve
+
+
+@register_specialize
+@node_rewriter(tracks=[Blockwise])
+def batched_solve_decomposition(fgraph, node):
+    if not (isinstance(node.op.core_op, Solve) and node.op.core_op.assume_a == "tridiagonal"):
+        return
+
+    a, b = node.inputs
+    [out] = node.outputs
+    batch_ndim = node.op.batch_ndim(node)
+
+    # Check if a is broadcast in computing the output
+    if not any(
+        a_bcast and not b_bcast
+        for a_bcast, b_bcast
+        in zip(a.type.broadcastable[:batch_ndim], b.type.broadcastable[:batch_ndim], strict=True)
+    ):
+        return
+
+    new_out = split_solve_tridiagonal(node)
+    return [new_out]
+
+
+@register_specialize
+@node_rewriter([Blockwise])
+def reuse_lu_decomp_multiple_solves(fgraph, node):
+
+    if not isinstance(node.op.core_op, Solve):
+        return None
+
+    assume_a = node.op.core_op.assume_a
+
+    if assume_a != "tridiagonal":
+        # Other assume_a not yet supported
+        return None
+
+    def find_solve_clients(var):
+        return [
+            cl
+            for cl, idx in fgraph.clients[var]
+            if idx == 0
+            and isinstance(cl.op, Blockwise)
+            and isinstance(cl.op.core_op, Solve)
+            and cl.op.core_op.assume_a == assume_a
+        ]
+
+    [A, _] = node.inputs
+    if A.owner is not None and isinstance(A.owner.op, DimShuffle):
+        # FIXME: Don't consider if dimshuffle mixes batch and core dims
+        [A] = A.owner.inputs
+
+    # Find Solves using A
+    A_solve_clients = [(client, False) for client in find_solve_clients(A)]
+
+    # Find Solves using A.T
+    for cl, _ in fgraph.clients[A]:
+        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
+            A_T = cl.out
+            A_solve_clients.extend((client, True) for client in find_solve_clients(A_T))
+
+    A_decomp = decompose_of_solve_tridiagonal(A)
+    replacements = {}
+    for client, transpose in A_solve_clients:
+        _, b = client.inputs
+    return replacements
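Note that the final loop is still a stub in this .progress commit: it unpacks each client's right-hand side but does not yet build replacement solves from A_decomp, so replacements is returned empty. A sketch of the kind of graph the finished rewrite is meant to optimize, assuming assume_a="tridiagonal" now reaches the generic Solve path (the eager lowering in slinalg.py is commented out below):

# Two solves sharing one tridiagonal A: once the rewrite is complete, both
# should reuse a single LU factorization instead of factoring A twice.
import pytensor
from pytensor.tensor import matrix, vector
from pytensor.tensor.slinalg import solve

A = matrix("A")
b1 = vector("b1")
b2 = vector("b2")

x1 = solve(A, b1, assume_a="tridiagonal")
x2 = solve(A, b2, assume_a="tridiagonal")

# Default (fast_run) compilation runs the "specialize" rewrites, giving
# reuse_lu_decomp_multiple_solves a chance to fire.
f = pytensor.function([A, b1, b2], [x1, x2])
pytensor.dprint(f)  # inspect whether the two solves share a decomposition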

pytensor/tensor/_linalg/solve/tridiagonal.py

Lines changed: 28 additions & 20 deletions
@@ -3,14 +3,16 @@
 from scipy.linalg import get_lapack_funcs
 
 from pytensor.graph import Op, Apply
-from pytensor.tensor import as_tensor, tensor, diagonal
+from pytensor.tensor.basic import as_tensor, diagonal
+from pytensor.tensor.type import tensor, vector
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.slinalg import Solve
 
 
 class LUFactorTridiagonal(Op):
     """Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
     __props__ = ("overwrite_dl", "overwrite_d", "overwrite_du",)
-    _gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
+    gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
 
     def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
         self.overwrite_dl = overwrite_dl
@@ -19,33 +21,34 @@ def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
         super().__init__()
 
     def make_node(self, dl, d, du):
-        dl, d, du = map(as_tensor, dl, d, du)
+        dl, d, du = map(as_tensor, (dl, d, du))
 
-        if not all(inp.type.ndim == 1 for inp in (dl, d, du))
+        if not all(inp.type.ndim == 1 for inp in (dl, d, du)):
             raise ValueError("Diagonals must be vectors")
 
         ndl, nd, ndu = (inp.type.shape[-1] for inp in (dl, d, du))
         n = (
             ndl + 1
             if ndl is not None else (
-                n if n is not None else (
-                    ndu + 1 if nu is not None else None
+                nd if nd is not None else (
+                    ndu + 1 if ndu is not None else None
                 )
             )
         )
         dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du)]
         out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
         outputs = [
-            vector(shape=(shape=(None if n is None else n - 1,), dtype=out_dtype),
+            vector(shape=(None if n is None else (n - 1),), dtype=out_dtype),
             vector(shape=(n,), dtype=out_dtype),
             vector(shape=(None if n is None else n - 1,), dtype=out_dtype),
             vector(shape=(None if n is None else n - 2,), dtype=out_dtype),
             vector(shape=(n,), dtype=np.int32),
         ]
+        return Apply(self, [dl, d, du], outputs)
 
     def perform(self, node, inputs, output_storage):
         gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
-        dl, d, du, du2, ipiv, _ = _gttrf(
+        dl, d, du, du2, ipiv, _ = gttrf(
             *inputs,
             overwrite_dl=self.overwrite_dl,
             overwrite_d=self.overwrite_d,
@@ -68,26 +71,26 @@ def __init__(self, b_ndim: int, overwrite_b=False):
         self.b_ndim = b_ndim
         self.overwrite_b = overwrite_b
         if b_ndim == 1:
-            _gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
+            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
         else:
-            _gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
+            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
 
     def make_node(self, dl, d, du, du2, ipiv, b):
-        dl, d, du, du2, ipiv, b = map(as_tensor, dl, d, du, du2, ipiv, b)
+        dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))
 
         if b.type.ndim != self.b_ndim:
             raise ValueError("Wrong number of dimensions for input b.")
 
-        if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv))
+        if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv)):
             raise ValueError("Inputs must be vectors")
 
         ndl, nd, ndu, ndu2, nipiv = (inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv))
         nb = b.type.shape[0]
         n = (
             ndl + 1
             if ndl is not None else (
-                n if n is not None else (
-                    ndu + 1 if nu is not None else (
+                nd if nd is not None else (
+                    ndu + 1 if ndu is not None else (
                         ndu2 + 2 if ndu2 is not None else (
                             nipiv if nipiv is not None else nb
                         )
@@ -101,14 +104,14 @@ def make_node(self, dl, d, du, du2, ipiv, b):
         if self.b_ndim == 1:
             output_shape = (n,)
         else:
-            output_shape = (n, n.type.shape[-1])
+            output_shape = (n, b.type.shape[-1])
 
-        outputs = [vector(shape=output_shape, dtype=out_dtype)]
+        outputs = [tensor(shape=output_shape, dtype=out_dtype)]
         return Apply(self, [dl, d, du, du2, ipiv, b], outputs)
 
     def perform(self, node, inputs, output_storage):
         gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
-        x, _ = _gttrs(
+        x, _ = gttrs(
             *inputs, overwrite_b=self.overwrite_b
         )
         output_storage[0][0] = x
@@ -149,7 +152,7 @@ def make_node(self, dl, d, du, b):
         return Apply(self, [dl, d, du, b], [out])
 
     def L_op(self, node, inputs, outputs, output_grads):
-        # TODO
+        pass
 
     def perform(self, node, inputs, output_storage):
         [dl, d, du, b] = inputs
@@ -193,8 +196,13 @@ def split_solve_tridiagonal(node):
     """
     assert isinstance(node.op, Blockwise)
    core_op = node.op.core_op
-    assert isinstance(core_op, Solve) and core.op.assume_a == "tridiagonal"
+    assert isinstance(core_op, Solve) and core_op.assume_a == "tridiagonal"
     a, b = node.inputs
+    dl, d, du, du2, ipiv = decompose_of_solve_tridiagonal(a)
+    return Blockwise(SolveLUFactorTridiagonal(b_ndim=node.op.core_op.b_ndim))(dl, d, du, du2, ipiv, b)
+
+def decompose_of_solve_tridiagonal(a):
+    # Return the decomposition of A implied by a tridiagonal solve
     dl, d, du = (diagonal(a, offset=o, axis1=-2, axis2=-1) for o in (-1, 0, 1))
     dl, d, du, du2, ipiv = Blockwise(LUFactorTridiagonal())(dl, d, du)
-    return Blockwise(SolveLUFactorTridiagonal(b_ndim=node.op.core.op.b_ndim))(dl, d, du)(dl, d, du, du2, ipiv)
+    return dl, d, du, du2, ipiv
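LUFactorTridiagonal and SolveLUFactorTridiagonal wrap the LAPACK routines gttrf and gttrs that perform() fetches via get_lapack_funcs. A standalone SciPy sketch of the factor-once, solve-cheaply split these two Ops implement:

# gttrf factors the three diagonals once; gttrs then solves in O(n) per
# right-hand side, which is what makes reusing the decomposition worthwhile.
import numpy as np
from scipy.linalg import get_lapack_funcs

n = 5
rng = np.random.default_rng(0)
dl = rng.normal(size=n - 1)    # sub-diagonal
d = rng.normal(size=n) + 4.0   # main diagonal, shifted to stay well conditioned
du = rng.normal(size=n - 1)    # super-diagonal
b = rng.normal(size=n)

gttrf, gttrs = get_lapack_funcs(("gttrf", "gttrs"), (d,))
dl_f, d_f, du_f, du2, ipiv, info = gttrf(dl, d, du)
assert info == 0
x, info = gttrs(dl_f, d_f, du_f, du2, ipiv, b)
assert info == 0

# Verify against a dense reconstruction of A
A = np.diag(dl, -1) + np.diag(d) + np.diag(du, 1)
np.testing.assert_allclose(A @ x, b, atol=1e-10)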

pytensor/tensor/slinalg.py

Lines changed: 10 additions & 10 deletions
@@ -905,7 +905,7 @@ class Solve(SolveBase):
 
     def __init__(self, *, assume_a="gen", **kwargs):
         # Triangular and diagonal are handled outside of Solve
-        valid_options = ["gen", "sym", "her", "pos", "banded"]
+        valid_options = ["gen", "sym", "her", "pos", "banded", "tridiagonal"]
 
         assume_a = assume_a.lower()
         # We use the old names as the different dispatches are more likely to support them
@@ -922,7 +922,7 @@ def __init__(self, *, assume_a="gen", **kwargs):
                 f"Invalid assume_a: {assume_a}. It must be one of {valid_options} or {list(long_to_short.keys())}"
             )
 
-        if assume_a == "banded":
+        if assume_a in ("tridiagonal", "banded"):
             from scipy import __version__ as sp_version
 
             if tuple(map(int, sp_version.split(".")[:-1])) < (1, 15):
@@ -1043,14 +1043,14 @@ def solve(
             b_ndim=b_ndim,
         )
 
-    elif assume_a == "tridiagonal":
-        from pytensor.tensor._linalg.solve.tridiagonal import (
-            solve_tridiagonal_from_full_A_b,
-        )
-
-        return solve_tridiagonal_from_full_A_b(
-            a, b, b_ndim=b_ndim, transposed=transposed
-        )
+    # elif assume_a == "tridiagonal":
+    #     from pytensor.tensor._linalg.solve.tridiagonal import (
+    #         solve_tridiagonal_from_full_A_b,
+    #     )
+    #
+    #     return solve_tridiagonal_from_full_A_b(
+    #         a, b, b_ndim=b_ndim, transposed=transposed
+    #     )
 
     elif assume_a == "diagonal":
         a_diagonal = diagonal(a, axis1=-2, axis2=-1)
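With "tridiagonal" accepted by Solve and the eager lowering commented out, solve(a, b, assume_a="tridiagonal") now builds an ordinary Blockwise Solve node and leaves the specialization to the rewrites above; at runtime the op presumably defers to scipy.linalg.solve, which only accepts this assume_a value from 1.15 on, hence the version gate. A direct SciPy check of that underlying behavior:

# SciPy >= 1.15 accepts assume_a="tridiagonal" and takes a specialized path.
import numpy as np
from scipy.linalg import solve as sp_solve

n = 5
A = np.diag(np.full(n - 1, -1.0), -1) + np.diag(np.full(n, 2.0)) + np.diag(np.full(n - 1, -1.0), 1)
b = np.ones(n)

x = sp_solve(A, b, assume_a="tridiagonal")
np.testing.assert_allclose(A @ x, b, atol=1e-12)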
