3
3
from scipy .linalg import get_lapack_funcs
4
4
5
5
from pytensor .graph import Op , Apply
6
- from pytensor .tensor import as_tensor , tensor , diagonal
6
+ from pytensor .tensor .basic import as_tensor , diagonal
7
+ from pytensor .tensor .type import tensor , vector
7
8
from pytensor .tensor .blockwise import Blockwise
9
+ from pytensor .tensor .slinalg import Solve
8
10
9
11
10
12
class LUFactorTridiagonal (Op ):
11
13
"""Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
12
14
__props__ = ("overwrite_dl" , "overwrite_d" , "overwrite_du" ,)
13
- _gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
15
+ gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
14
16
15
17
def __init__ (self , overwrite_dl = False , overwrite_d = False , overwrite_du = False ):
16
18
self .overwrite_dl = overwrite_dl
@@ -19,33 +21,34 @@ def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
19
21
super ().__init__ ()
20
22
21
23
def make_node (self , dl , d , du ):
22
- dl , d , du = map (as_tensor , dl , d , du )
24
+ dl , d , du = map (as_tensor , ( dl , d , du ) )
23
25
24
- if not all (inp .type .ndim == 1 for inp in (dl , d , du ))
26
+ if not all (inp .type .ndim == 1 for inp in (dl , d , du )):
25
27
raise ValueError ("Diagonals must be vectors" )
26
28
27
29
ndl , nd , ndu = (inp .type .shape [- 1 ] for inp in (dl , d , du ))
28
30
n = (
29
31
ndl + 1
30
32
if ndl is not None else (
31
- n if n is not None else (
32
- ndu + 1 if nu is not None else None
33
+ nd if nd is not None else (
34
+ ndu + 1 if ndu is not None else None
33
35
)
34
36
)
35
37
)
36
38
dummy_arrays = [np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du )]
37
39
out_dtype = get_lapack_funcs ("gttrf" , dummy_arrays ).dtype
38
40
outputs = [
39
- vector (shape = (shape = ( None if n is None else n - 1 ,), dtype = out_dtype ),
41
+ vector (shape = (None if n is None else ( n - 1 ) ,), dtype = out_dtype ),
40
42
vector (shape = (n ,), dtype = out_dtype ),
41
43
vector (shape = (None if n is None else n - 1 ,), dtype = out_dtype ),
42
44
vector (shape = (None if n is None else n - 2 ,), dtype = out_dtype ),
43
45
vector (shape = (n ,), dtype = np .int32 ),
44
46
]
47
+ return Apply (self , [dl , d , du ], outputs )
45
48
46
49
def perform (self , node , inputs , output_storage ):
47
50
gttrf = get_lapack_funcs ("gttrf" , dtype = node .outputs [0 ].type .dtype )
48
- dl , d , du , du2 , ipiv , _ = _gttrf (
51
+ dl , d , du , du2 , ipiv , _ = gttrf (
49
52
* inputs ,
50
53
overwrite_dl = self .overwrite_dl ,
51
54
overwrite_d = self .overwrite_d ,
@@ -68,26 +71,26 @@ def __init__(self, b_ndim: int, overwrite_b=False):
68
71
self .b_ndim = b_ndim
69
72
self .overwrite_b = overwrite_b
70
73
if b_ndim == 1 :
71
- _gufunc_signature = "(dl ),(d ),(dl ),(du2 ),(d ),(d )- > (d )
74
+ self . gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
72
75
else :
73
- _gufunc_signature = "(dl ),(d ),(dl ),(du2 ),(d ),(d ,rhs )- > (d ,rhs )
76
+ self . gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
74
77
75
78
def make_node (self , dl , d , du , du2 , ipiv , b ):
76
- dl , d , du , du2 , ipiv , b = map (as_tensor , dl , d , du , du2 , ipiv , b )
79
+ dl , d , du , du2 , ipiv , b = map (as_tensor , ( dl , d , du , du2 , ipiv , b ) )
77
80
78
81
if b .type .ndim != self .b_ndim :
79
82
raise ValueError ("Wrang number of dimensions for input b." )
80
83
81
- if not all (inp .type .ndim == 1 for inp in (dl , d , du , du2 , ipiv ))
84
+ if not all (inp .type .ndim == 1 for inp in (dl , d , du , du2 , ipiv )):
82
85
raise ValueError ("Inputs must be vectors" )
83
86
84
87
ndl , nd , ndu , ndu2 , nipiv = (inp .type .shape [- 1 ] for inp in (dl , d , du , du2 , ipiv ))
85
88
nb = b .type .shape [0 ]
86
89
n = (
87
90
ndl + 1
88
91
if ndl is not None else (
89
- n if n is not None else (
90
- ndu + 1 if nu is not None else (
92
+ nd if nd is not None else (
93
+ ndu + 1 if ndu is not None else (
91
94
ndu2 + 2 if ndu2 is not None else (
92
95
nipiv if nipiv is not None else nb
93
96
)
@@ -101,14 +104,14 @@ def make_node(self, dl, d, du, du2, ipiv, b):
101
104
if self .b_ndim == 1 :
102
105
output_shape = (n ,)
103
106
else :
104
- output_shape = (n , n .type .shape [- 1 ])
107
+ output_shape = (n , b .type .shape [- 1 ])
105
108
106
- outputs = [vector (shape = output_shape , dtype = out_dtype )]
109
+ outputs = [tensor (shape = output_shape , dtype = out_dtype )]
107
110
return Apply (self , [dl , d , du , du2 , ipiv , b ], outputs )
108
111
109
112
def perform (self , node , inputs , output_storage ):
110
113
gttrs = get_lapack_funcs ("gttrs" , dtype = node .outputs [0 ].type .dtype )
111
- x , _ = _gttrs (
114
+ x , _ = gttrs (
112
115
* inputs , overwrite_b = self .overwrite_b
113
116
)
114
117
output_storage [0 ][0 ] = x
@@ -149,7 +152,7 @@ def make_node(self, dl, d, du, b):
149
152
return Apply (self , [dl , d , du , b ], [out ])
150
153
151
154
def L_op (self , node , inputs , outputs , output_grads ):
152
- # TODO
155
+ pass
153
156
154
157
def perform (self , node , inputs , output_storage ):
155
158
[dl , d , du , b ] = inputs
@@ -193,8 +196,13 @@ def split_solve_tridiagonal(node):
193
196
"""
194
197
assert isinstance (node .op , Blockwise )
195
198
core_op = node .op .core_op
196
- assert isinstance (core_op , Solve ) and core . op .assume_a == "tridiagonal"
199
+ assert isinstance (core_op , Solve ) and core_op .assume_a == "tridiagonal"
197
200
a , b = node .inputs
201
+ dl , d , du , du2 , ipiv = decompose_of_solve_tridiagonal (a )
202
+ return Blockwise (SolveLUFactorTridiagonal (b_ndim = node .op .core_op .b_ndim ))(dl , d , du , du2 , ipiv , b )
203
+
204
+ def decompose_of_solve_tridiagonal (a ):
205
+ # Return the decomposition of A implied by a solve tridiagonal
198
206
dl , d , du = (diagonal (a , offset = o , axis1 = - 2 , axis2 = - 1 ) for o in (- 1 , 0 , 1 ))
199
207
dl , d , du , du2 , ipiv = Blockwise (LUFactorTridiagonal ())(dl , d , du )
200
- return Blockwise ( SolveLUFactorTridiagonal ( b_ndim = node . op . core . op . b_ndim ))( dl , d , du )( dl , d , du , du2 , ipiv )
208
+ return dl , d , du , du2 , ipiv
0 commit comments