Skip to content

Commit 3b173b3

Browse files
committed
Decompose Tridiagonal Solve into core steps
1 parent 4378d48 commit 3b173b3

File tree

10 files changed

+389
-70
lines changed

10 files changed

+389
-70
lines changed

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 19 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from numba.core.extending import overload
55
from numba.np.linalg import ensure_lapack
66
from numpy import ndarray
7-
from scipy import linalg
87

98
from pytensor.link.numba.dispatch.basic import numba_njit
109
from pytensor.link.numba.dispatch.linalg._LAPACK import (
@@ -13,11 +12,9 @@
1312
int_ptr_to_val,
1413
val_to_int_ptr,
1514
)
16-
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
1715
from pytensor.link.numba.dispatch.linalg.utils import (
1816
_check_scipy_linalg_matrix,
1917
_copy_to_fortran_order_even_if_1d,
20-
_solve_check,
2118
_trans_char_to_int,
2219
)
2320

@@ -227,72 +224,42 @@ def impl(
227224

228225

229226
def _solve_tridiagonal(
230-
a: ndarray,
231-
b: ndarray,
232-
lower: bool,
233-
overwrite_a: bool,
227+
dl: ndarray,
228+
d: ndarray,
229+
du: ndarray,
230+
B: ndarray,
234231
overwrite_b: bool,
235-
check_finite: bool,
236-
transposed: bool,
237232
):
238233
"""
239-
Solve a positive-definite linear system using the Cholesky decomposition.
234+
Solve a tridiagonal linear system.
240235
"""
241-
return linalg.solve(
242-
a=a,
243-
b=b,
244-
lower=lower,
245-
overwrite_a=overwrite_a,
246-
overwrite_b=overwrite_b,
247-
check_finite=check_finite,
248-
transposed=transposed,
249-
assume_a="tridiagonal",
250-
)
236+
return
251237

252238

253239
@overload(_solve_tridiagonal)
254-
def _tridiagonal_solve_impl(
255-
A: ndarray,
240+
def _solve_tridiagonal_impl(
241+
dl: ndarray,
242+
d: ndarray,
243+
du: ndarray,
256244
B: ndarray,
257-
lower: bool,
258-
overwrite_a: bool,
259245
overwrite_b: bool,
260-
check_finite: bool,
261-
transposed: bool,
262-
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
246+
) -> Callable[[ndarray, ndarray, ndarray, ndarray, bool], ndarray]:
263247
ensure_lapack()
264-
_check_scipy_linalg_matrix(A, "solve")
248+
_check_scipy_linalg_matrix(dl, "solve")
249+
_check_scipy_linalg_matrix(d, "solve")
250+
_check_scipy_linalg_matrix(du, "solve")
265251
_check_scipy_linalg_matrix(B, "solve")
266252

267253
def impl(
268-
A: ndarray,
254+
dl: ndarray,
255+
d: ndarray,
256+
du: ndarray,
269257
B: ndarray,
270-
lower: bool,
271-
overwrite_a: bool,
272258
overwrite_b: bool,
273-
check_finite: bool,
274-
transposed: bool,
275259
) -> ndarray:
276-
n = np.int32(A.shape[-1])
277-
_solve_check_input_shapes(A, B)
278-
norm = "1"
279-
280-
if transposed:
281-
A = A.T
282-
dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)
283-
284-
anorm = tridiagonal_norm(du, d, dl)
285-
286-
dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
287-
_solve_check(n, INFO)
288-
289-
X, INFO = _gttrs(
290-
dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b
291-
)
292-
_solve_check(n, INFO)
260+
dl, d, du, du2, IPIV, _ = _gttrf(dl, d, du)
293261

294-
RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
295-
_solve_check(n, INFO, True, RCOND)
262+
X, _ = _gttrs(dl, d, du, du2, IPIV, B, trans=0, overwrite_b=overwrite_b)
296263

297264
return X
298265

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
1818
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
1919
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
20+
from pytensor.tensor._linalg.solve.tridiagonal import SolveTridiagonal
2021
from pytensor.tensor.slinalg import (
2122
LU,
2223
BlockDiagonal,
@@ -215,8 +216,6 @@ def numba_funcify_Solve(op, node, **kwargs):
215216
solve_fn = _solve_symmetric
216217
elif assume_a == "pos":
217218
solve_fn = _solve_psd
218-
elif assume_a == "tridiagonal":
219-
solve_fn = _solve_tridiagonal
220219
else:
221220
warnings.warn(
222221
f"Numba assume_a={assume_a} not implemented. Falling back to general solve.\n"
@@ -284,6 +283,32 @@ def solve_triangular(a, b):
284283
return solve_triangular
285284

286285

286+
@numba_funcify.register(SolveTridiagonal)
287+
def numba_funcify_SolveTridiagonal(op, node, **kwargs):
288+
overwrite_b = op.overwrite_b
289+
b_ndim = op.b_ndim
290+
291+
dtype = node.inputs[0].dtype
292+
if dtype in complex_dtypes:
293+
raise NotImplementedError(
294+
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Tridiagonal")
295+
)
296+
297+
@numba_njit
298+
def solve_triangular(dl, d, du, b):
299+
res = _solve_tridiagonal(
300+
dl,
301+
d,
302+
du,
303+
b,
304+
overwrite_b=overwrite_b,
305+
)
306+
307+
return res
308+
309+
return solve_triangular
310+
311+
287312
@numba_funcify.register(CholeskySolve)
288313
def numba_funcify_CholeskySolve(op, node, **kwargs):
289314
lower = op.lower

pytensor/scan/rewriting.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,6 +2561,14 @@ def scan_push_out_dot1(fgraph, node):
25612561
position=1,
25622562
)
25632563

2564+
scan_seqopt1.register(
2565+
"scan_decompose_compound_ops",
2566+
in2out(scan_decompose_compound_ops),
2567+
"fast_run",
2568+
"scan",
2569+
"scan_pushout",
2570+
position=2,
2571+
)
25642572

25652573
scan_seqopt1.register(
25662574
"scan_push_out_non_seq",
@@ -2569,18 +2577,17 @@ def scan_push_out_dot1(fgraph, node):
25692577
"fast_run",
25702578
"scan",
25712579
"scan_pushout",
2572-
position=2,
2580+
position=3,
25732581
)
25742582

2575-
25762583
scan_seqopt1.register(
25772584
"scan_push_out_seq",
25782585
in2out(scan_push_out_seq, ignore_newtrees=True),
25792586
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
25802587
"fast_run",
25812588
"scan",
25822589
"scan_pushout",
2583-
position=3,
2590+
position=4,
25842591
)
25852592

25862593

@@ -2592,7 +2599,7 @@ def scan_push_out_dot1(fgraph, node):
25922599
"more_mem",
25932600
"scan",
25942601
"scan_pushout",
2595-
position=4,
2602+
position=5,
25962603
)
25972604

25982605

@@ -2605,7 +2612,7 @@ def scan_push_out_dot1(fgraph, node):
26052612
"more_mem",
26062613
"scan",
26072614
"scan_pushout",
2608-
position=5,
2615+
position=6,
26092616
)
26102617

26112618
scan_eqopt2.register(

pytensor/tensor/_linalg/__init__.py

Whitespace-only changes.

pytensor/tensor/_linalg/solve/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)