1
- import inspect
2
1
from functools import singledispatch
3
2
from numbers import Number
4
3
from textwrap import indent
25
24
use_optimized_cheap_pass ,
26
25
)
27
26
from pytensor .link .numba .dispatch .helpers import check_broadcasting , tuple_mapper
28
- from pytensor .link .utils import (
29
- compile_function_src ,
30
- get_name_for_object ,
31
- unique_name_generator ,
32
- )
27
+ from pytensor .link .utils import compile_function_src , get_name_for_object
33
28
from pytensor .scalar .basic import (
34
29
AND ,
35
30
OR ,
@@ -447,7 +442,7 @@ def _vectorize_bc(
447
442
noalias_outputs = False ,
448
443
):
449
444
450
- flags = True
445
+ flags = True
451
446
{
452
447
"arcp" , # Allow Reciprocal
453
448
"contract" , # Allow floating-point contraction
@@ -470,8 +465,6 @@ def loop_call(typingctx, *args):
470
465
scalar_func , [in_type .dtype for in_type in in_types ], {}
471
466
)
472
467
473
- ndim = iter_shape_type .count
474
-
475
468
sig = types .void (types .StarArgTuple ([* out_types , * in_types , iter_shape_type ]))
476
469
477
470
def codegen (context , builder , signature , args ):
@@ -491,7 +484,10 @@ def codegen(context, builder, signature, args):
491
484
# Lower the code of the scalar function so that we can use it in the inner loop
492
485
# Caching is set to false to avoid a numba bug TODO ref?
493
486
inner_func = context .compile_subroutine (
494
- builder , scalar_func , scalar_signature , caching = False ,
487
+ builder ,
488
+ scalar_func ,
489
+ scalar_signature ,
490
+ caching = False ,
495
491
)
496
492
inner = inner_func .fndesc
497
493
@@ -508,12 +504,12 @@ def extract_array(aryty, ary):
508
504
# TODO I think this is better than the noalias attribute
509
505
# for the input, but self_ref isn't supported in a released
510
506
# llvmlite version yet
511
- #mod = builder.module
512
- #domain = mod.add_metadata([], self_ref=True)
513
- #input_scope = mod.add_metadata([domain], self_ref=True)
514
- #output_scope = mod.add_metadata([domain], self_ref=True)
515
- #input_scope_set = mod.add_metadata([input_scope, output_scope])
516
- #output_scope_set = mod.add_metadata([input_scope, output_scope])
507
+ # mod = builder.module
508
+ # domain = mod.add_metadata([], self_ref=True)
509
+ # input_scope = mod.add_metadata([domain], self_ref=True)
510
+ # output_scope = mod.add_metadata([domain], self_ref=True)
511
+ # input_scope_set = mod.add_metadata([input_scope, output_scope])
512
+ # output_scope_set = mod.add_metadata([input_scope, output_scope])
517
513
518
514
inputs = [
519
515
extract_array (aryty , ary )
@@ -559,8 +555,8 @@ def extract_array(aryty, ary):
559
555
context , builder , * array_info , idxs_bc , * safe
560
556
)
561
557
val = builder .load (ptr )
562
- #val.set_metadata("alias.scope", input_scope_set)
563
- #val.set_metadata("noalias", output_scope_set)
558
+ # val.set_metadata("alias.scope", input_scope_set)
559
+ # val.set_metadata("noalias", output_scope_set)
564
560
input_vals .append (val )
565
561
566
562
# Call scalar function
@@ -581,13 +577,13 @@ def extract_array(aryty, ary):
581
577
):
582
578
if accu is not None :
583
579
load = builder .load (accu )
584
- #load.set_metadata("alias.scope", output_scope_set)
585
- #load.set_metadata("noalias", input_scope_set)
580
+ # load.set_metadata("alias.scope", output_scope_set)
581
+ # load.set_metadata("noalias", input_scope_set)
586
582
new_value = builder .fadd (load , value )
587
- store = builder .store (new_value , accu )
588
- # TODO ?
589
- #store.set_metadata("alias.scope", output_scope_set)
590
- #store.set_metadata("noalias", input_scope_set)
583
+ builder .store (new_value , accu )
584
+ # TODO belongs to noalias scope
585
+ # store.set_metadata("alias.scope", output_scope_set)
586
+ # store.set_metadata("noalias", input_scope_set)
591
587
else :
592
588
idxs_bc = [
593
589
zero if bc else idx
@@ -596,10 +592,10 @@ def extract_array(aryty, ary):
596
592
ptr = cgutils .get_item_pointer2 (
597
593
context , builder , * outputs [i ], idxs_bc
598
594
)
599
- #store = builder.store(value, ptr)
600
- store = arrayobj .store_item (context , builder , out_types [i ], value , ptr )
601
- #store.set_metadata("alias.scope", output_scope_set)
602
- #store.set_metadata("noalias", input_scope_set)
595
+ # store = builder.store(value, ptr)
596
+ arrayobj .store_item (context , builder , out_types [i ], value , ptr )
597
+ # store.set_metadata("alias.scope", output_scope_set)
598
+ # store.set_metadata("noalias", input_scope_set)
603
599
604
600
# Close the loops and write accumulator values to the output arrays
605
601
for depth , loop in enumerate (loop_stack [::- 1 ]):
@@ -615,12 +611,14 @@ def extract_array(aryty, ary):
615
611
context , builder , * outputs [output ], idxs_bc
616
612
)
617
613
load = builder .load (accu )
618
- #load.set_metadata("alias.scope", output_scope_set)
619
- #load.set_metadata("noalias", input_scope_set)
620
- #store = builder.store(load, ptr)
621
- store = arrayobj .store_item (context , builder , out_types [output ], load , ptr )
622
- #store.set_metadata("alias.scope", output_scope_set)
623
- #store.set_metadata("noalias", input_scope_set)
614
+ # load.set_metadata("alias.scope", output_scope_set)
615
+ # load.set_metadata("noalias", input_scope_set)
616
+ # store = builder.store(load, ptr)
617
+ arrayobj .store_item (
618
+ context , builder , out_types [output ], load , ptr
619
+ )
620
+ # store.set_metadata("alias.scope", output_scope_set)
621
+ # store.set_metadata("noalias", input_scope_set)
624
622
loop .__exit__ (None , None , None )
625
623
return
626
624
@@ -672,6 +670,7 @@ def make_output(iter_shape, bc, dtype):
672
670
673
671
check_arrays = check_broadcasting
674
672
else :
673
+
675
674
@numba .extending .register_jitable
676
675
def make_output (iter_shape , bc , dtype ):
677
676
return np .empty ((), dtype )
@@ -680,7 +679,6 @@ def make_output(iter_shape, bc, dtype):
680
679
def check_arrays (a , b , c ):
681
680
pass
682
681
683
-
684
682
make_outputs = tuple_mapper (make_output )
685
683
686
684
def impl (* inputs ):
@@ -756,10 +754,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
756
754
scalar_op_fn ,
757
755
input_bc_patterns ,
758
756
output_bc_patterns ,
759
- output_dtypes = tuple ([
760
- variable .dtype
761
- for variable in node .outputs
762
- ]),
757
+ output_dtypes = tuple ([variable .dtype for variable in node .outputs ]),
763
758
)
764
759
765
760
# TODO We should do this in vectorize instead
@@ -771,13 +766,16 @@ def elemwise_inplace(*inputs):
771
766
outputs = vectorized (* inputs )
772
767
for out_idx , in_idx in literal_unroll (pattern ):
773
768
inputs [in_idx ][...] = outputs [out_idx ]
769
+
774
770
else :
775
771
elemwise_inplace = vectorized
776
772
777
773
if len (node .outputs ) == 1 :
774
+
778
775
@numba_njit
779
776
def elemwise_wrapper (* inputs ):
780
777
return elemwise_inplace (* inputs )[0 ]
778
+
781
779
else :
782
780
elemwise_wrapper = vectorized
783
781
0 commit comments