@@ -1,19 +1,22 @@
 from collections.abc import Iterable, Sequence
+from typing import cast
 
 import numpy as np
 
 from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
+    Join,
     MakeVector,
     alloc,
     as_tensor,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     register_infer_shape,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -44,6 +47,7 @@
 )
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.variable import TensorVariable
 
 
 def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
@@ -66,6 +70,41 @@ def _axis_is_indexed_by_basic_index(
     return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
 
 
+def _lift_subtensor_non_axis(
+    local_subtensor_lift_rewrite: NodeRewriter,
+    fgraph: FunctionGraph,
+    variable: TensorVariable,
+    idx_tuple: tuple[int | slice],
+    axis: int,
+    old_subtensor_variable: TensorVariable,
+) -> None | list[TensorVariable]:
+    # Apply generic subtensor lift rewrite along "non-axis" dimensions
+    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+    if len(real_indices) > 1 and variable.type.ndim > 1:
+        # Split the subtensor
+        idx_to_keep = idx_tuple[axis]
+        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+        # Lift the non-axis indexes by calling the rewrite itself
+        indexed_variable = variable[idxs_to_lift]
+        [indexed_variable] = cast(
+            list[TensorVariable],
+            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
+        )
+        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)
+
+        # Then reintroduce the axis index
+        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+        new_axis = axis - ndim_reduced_left
+        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+        new_out = indexed_variable[idxs_to_keep]
+        copy_stack_trace(old_subtensor_variable, new_out)
+        return [new_out]
+
+    else:
+        return None
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
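Note (illustration only, not part of the diff): the index-splitting identity the new helper relies on can be checked with plain numpy. All names below are local to the example.

# The helper replaces the kept axis's index with a full slice, lifts the
# remaining indices, then reapplies the kept index at its shifted position.
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
idx_tuple = (0, 1, slice(None))  # x[0, 1, :]
axis = 1                         # dimension whose index must stay on the output

idx_to_keep = idx_tuple[axis]
idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])  # (0, :, :)

# One basic (integer) index left of `axis` drops a dimension, so the kept
# axis shifts left by one: new_axis = 1 - 1 = 0
new_axis = axis - 1
idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)  # (1,)

assert np.array_equal(x[idx_tuple], x[idxs_to_lift][idxs_to_keep])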
@@ -297,29 +336,14 @@ def local_subtensor_of_softmax(fgraph, node):
     if _axis_is_indexed_by_basic_index(idx_tuple, axis):
         # If there are more dimensions being indexed, we can split them
         # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
 
     # Index input to softmax
     x_sub = x[idx_tuple]
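Note (hedged sanity check, not part of the diff): the refactor delegates to the shared helper, so the softmax rewrite should keep the same observable behaviour. A minimal sketch, assuming pytensor's public `softmax` and `Variable.eval` APIs:

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.matrix("x")
out = softmax(x, axis=-1)[0]     # Subtensor on axis 0, softmax over the last axis
lifted = softmax(x[0], axis=-1)  # the lifted form the rewrite aims for

# Both graphs compute the same values regardless of whether the rewrite fires
x_val = np.random.default_rng(0).normal(size=(3, 4))
np.testing.assert_allclose(out.eval({x: x_val}), lifted.eval({x: x_val}))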
@@ -646,6 +670,52 @@ def local_subtensor_make_vector(fgraph, node):
         pass
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_join(fgraph, node):
+    """Lift a Subtensor through a Join.
+
+    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
+    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
+
+    """
+    join_var, *idx = node.inputs
+
+    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
+        return None
+
+    if len(fgraph.clients[join_var]) > 1:
+        # Join involves a full_copy, so we don't want to do it twice
+        return None
+
+    join_axis, *join_components = join_var.owner.inputs
+
+    # Rewrite only works when the join axis is a constant along a non-indexed dimension
+    if not isinstance(join_axis, Constant):
+        return None
+
+    [old_out] = node.outputs
+    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_join,
+            fgraph=fgraph,
+            variable=join_var,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )
+
+    # Lift index to the Join components
+    indexed_components = [component[idx_tuple] for component in join_components]
+    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
+    out = join(new_axis, *indexed_components)
+
+    return [out]
+
+
 @register_specialize
 @register_canonicalize
 @node_rewriter([Subtensor])
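Note (hedged usage sketch, not part of the diff): the transformation the new `local_subtensor_of_join` rewrite targets, mirroring its docstring. Example shapes and values are made up.

import numpy as np
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")
graph = pt.join(1, x, y)[0]        # Subtensor on axis 0, Join on axis 1
lifted = pt.join(0, x[0], y[0])    # the lifted form from the docstring

# Indexing a non-join axis commutes with joining the indexed components
x_val = np.arange(6).reshape(2, 3).astype(x.dtype)
y_val = np.arange(8).reshape(2, 4).astype(y.dtype)
np.testing.assert_allclose(
    graph.eval({x: x_val, y: y_val}),
    lifted.eval({x: x_val, y: y_val}),
)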