Skip to content

Commit 089db94

Browse files
committed
Run black
1 parent 02ea662 commit 089db94

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
from functools import singledispatch
32
from numbers import Number
43
from textwrap import indent
@@ -25,11 +24,7 @@
2524
use_optimized_cheap_pass,
2625
)
2726
from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
28-
from pytensor.link.utils import (
29-
compile_function_src,
30-
get_name_for_object,
31-
unique_name_generator,
32-
)
27+
from pytensor.link.utils import compile_function_src, get_name_for_object
3328
from pytensor.scalar.basic import (
3429
AND,
3530
OR,
@@ -447,7 +442,7 @@ def _vectorize_bc(
447442
noalias_outputs=False,
448443
):
449444

450-
flags = True
445+
flags = True
451446
{
452447
"arcp", # Allow Reciprocal
453448
"contract", # Allow floating-point contraction
@@ -470,8 +465,6 @@ def loop_call(typingctx, *args):
470465
scalar_func, [in_type.dtype for in_type in in_types], {}
471466
)
472467

473-
ndim = iter_shape_type.count
474-
475468
sig = types.void(types.StarArgTuple([*out_types, *in_types, iter_shape_type]))
476469

477470
def codegen(context, builder, signature, args):
@@ -491,7 +484,10 @@ def codegen(context, builder, signature, args):
491484
# Lower the code of the scalar function so that we can use it in the inner loop
492485
# Caching is set to false to avoid a numba bug TODO ref?
493486
inner_func = context.compile_subroutine(
494-
builder, scalar_func, scalar_signature, caching=False,
487+
builder,
488+
scalar_func,
489+
scalar_signature,
490+
caching=False,
495491
)
496492
inner = inner_func.fndesc
497493

@@ -508,12 +504,12 @@ def extract_array(aryty, ary):
508504
# TODO I think this is better than the noalias attribute
509505
# for the input, but self_ref isn't supported in a released
510506
# llvmlite version yet
511-
#mod = builder.module
512-
#domain = mod.add_metadata([], self_ref=True)
513-
#input_scope = mod.add_metadata([domain], self_ref=True)
514-
#output_scope = mod.add_metadata([domain], self_ref=True)
515-
#input_scope_set = mod.add_metadata([input_scope, output_scope])
516-
#output_scope_set = mod.add_metadata([input_scope, output_scope])
507+
# mod = builder.module
508+
# domain = mod.add_metadata([], self_ref=True)
509+
# input_scope = mod.add_metadata([domain], self_ref=True)
510+
# output_scope = mod.add_metadata([domain], self_ref=True)
511+
# input_scope_set = mod.add_metadata([input_scope, output_scope])
512+
# output_scope_set = mod.add_metadata([input_scope, output_scope])
517513

518514
inputs = [
519515
extract_array(aryty, ary)
@@ -559,8 +555,8 @@ def extract_array(aryty, ary):
559555
context, builder, *array_info, idxs_bc, *safe
560556
)
561557
val = builder.load(ptr)
562-
#val.set_metadata("alias.scope", input_scope_set)
563-
#val.set_metadata("noalias", output_scope_set)
558+
# val.set_metadata("alias.scope", input_scope_set)
559+
# val.set_metadata("noalias", output_scope_set)
564560
input_vals.append(val)
565561

566562
# Call scalar function
@@ -581,13 +577,13 @@ def extract_array(aryty, ary):
581577
):
582578
if accu is not None:
583579
load = builder.load(accu)
584-
#load.set_metadata("alias.scope", output_scope_set)
585-
#load.set_metadata("noalias", input_scope_set)
580+
# load.set_metadata("alias.scope", output_scope_set)
581+
# load.set_metadata("noalias", input_scope_set)
586582
new_value = builder.fadd(load, value)
587-
store = builder.store(new_value, accu)
588-
# TODO ?
589-
#store.set_metadata("alias.scope", output_scope_set)
590-
#store.set_metadata("noalias", input_scope_set)
583+
builder.store(new_value, accu)
584+
# TODO belongs to noalias scope
585+
# store.set_metadata("alias.scope", output_scope_set)
586+
# store.set_metadata("noalias", input_scope_set)
591587
else:
592588
idxs_bc = [
593589
zero if bc else idx
@@ -596,10 +592,10 @@ def extract_array(aryty, ary):
596592
ptr = cgutils.get_item_pointer2(
597593
context, builder, *outputs[i], idxs_bc
598594
)
599-
#store = builder.store(value, ptr)
600-
store = arrayobj.store_item(context, builder, out_types[i], value, ptr)
601-
#store.set_metadata("alias.scope", output_scope_set)
602-
#store.set_metadata("noalias", input_scope_set)
595+
# store = builder.store(value, ptr)
596+
arrayobj.store_item(context, builder, out_types[i], value, ptr)
597+
# store.set_metadata("alias.scope", output_scope_set)
598+
# store.set_metadata("noalias", input_scope_set)
603599

604600
# Close the loops and write accumulator values to the output arrays
605601
for depth, loop in enumerate(loop_stack[::-1]):
@@ -615,12 +611,14 @@ def extract_array(aryty, ary):
615611
context, builder, *outputs[output], idxs_bc
616612
)
617613
load = builder.load(accu)
618-
#load.set_metadata("alias.scope", output_scope_set)
619-
#load.set_metadata("noalias", input_scope_set)
620-
#store = builder.store(load, ptr)
621-
store = arrayobj.store_item(context, builder, out_types[output], load, ptr)
622-
#store.set_metadata("alias.scope", output_scope_set)
623-
#store.set_metadata("noalias", input_scope_set)
614+
# load.set_metadata("alias.scope", output_scope_set)
615+
# load.set_metadata("noalias", input_scope_set)
616+
# store = builder.store(load, ptr)
617+
arrayobj.store_item(
618+
context, builder, out_types[output], load, ptr
619+
)
620+
# store.set_metadata("alias.scope", output_scope_set)
621+
# store.set_metadata("noalias", input_scope_set)
624622
loop.__exit__(None, None, None)
625623
return
626624

@@ -672,6 +670,7 @@ def make_output(iter_shape, bc, dtype):
672670

673671
check_arrays = check_broadcasting
674672
else:
673+
675674
@numba.extending.register_jitable
676675
def make_output(iter_shape, bc, dtype):
677676
return np.empty((), dtype)
@@ -680,7 +679,6 @@ def make_output(iter_shape, bc, dtype):
680679
def check_arrays(a, b, c):
681680
pass
682681

683-
684682
make_outputs = tuple_mapper(make_output)
685683

686684
def impl(*inputs):
@@ -756,10 +754,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
756754
scalar_op_fn,
757755
input_bc_patterns,
758756
output_bc_patterns,
759-
output_dtypes=tuple([
760-
variable.dtype
761-
for variable in node.outputs
762-
]),
757+
output_dtypes=tuple([variable.dtype for variable in node.outputs]),
763758
)
764759

765760
# TODO We should do this in vectorize instead
@@ -771,13 +766,16 @@ def elemwise_inplace(*inputs):
771766
outputs = vectorized(*inputs)
772767
for out_idx, in_idx in literal_unroll(pattern):
773768
inputs[in_idx][...] = outputs[out_idx]
769+
774770
else:
775771
elemwise_inplace = vectorized
776772

777773
if len(node.outputs) == 1:
774+
778775
@numba_njit
779776
def elemwise_wrapper(*inputs):
780777
return elemwise_inplace(*inputs)[0]
778+
781779
else:
782780
elemwise_wrapper = vectorized
783781

0 commit comments

Comments
 (0)