Skip to content

Commit 02ea662

Browse files
committed
Aliasing in elemwise and some fixes
1 parent 93a02c0 commit 02ea662

File tree

1 file changed

+89
-42
lines changed

1 file changed

+89
-42
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 89 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,13 @@ def _vectorize_bc(
447447
noalias_outputs=False,
448448
):
449449

450-
flags = {
450+
flags = True
451+
{
451452
"arcp", # Allow Reciprocal
452453
"contract", # Allow floating-point contraction
453454
"afn", # Approximate functions
454455
"reassoc",
456+
"nsz", # TODO Do we want this one?
455457
}
456458

457459
n_inputs = len(input_bc_patterns)
@@ -473,6 +475,9 @@ def loop_call(typingctx, *args):
473475
sig = types.void(types.StarArgTuple([*out_types, *in_types, iter_shape_type]))
474476

475477
def codegen(context, builder, signature, args):
478+
for i in [0]:
479+
arg = builder.function.args[i]
480+
arg.add_attribute("noalias")
476481
safe = (boundscheck, False)
477482
[args] = args
478483
args = cgutils.unpack_tuple(builder, args)
@@ -485,9 +490,10 @@ def codegen(context, builder, signature, args):
485490

486491
# Lower the code of the scalar function so that we can use it in the inner loop
487492
# Caching is set to false to avoid a numba bug TODO ref?
488-
inner = context.compile_subroutine(
493+
inner_func = context.compile_subroutine(
489494
builder, scalar_func, scalar_signature, caching=False,
490-
).fndesc
495+
)
496+
inner = inner_func.fndesc
491497

492498
# Extract shape and stride information from the array.
493499
# For later use in the loop body to do the indexing
@@ -499,13 +505,15 @@ def extract_array(aryty, ary):
499505
layout = aryty.layout
500506
return (data, shape, strides, layout)
501507

502-
mod = builder.module
503-
domain = mod.add_metadata([], self_ref=True)
504-
input_scope = mod.add_metadata([domain], self_ref=True)
505-
output_scope = mod.add_metadata([domain], self_ref=True)
506-
input_scope_set = mod.add_metadata([input_scope, output_scope])
507-
508-
output_scope_set = mod.add_metadata([input_scope, output_scope])
508+
# TODO I think this is better than the noalias attribute
509+
# for the input, but self_ref isn't supported in a released
510+
# llvmlite version yet
511+
#mod = builder.module
512+
#domain = mod.add_metadata([], self_ref=True)
513+
#input_scope = mod.add_metadata([domain], self_ref=True)
514+
#output_scope = mod.add_metadata([domain], self_ref=True)
515+
#input_scope_set = mod.add_metadata([input_scope, output_scope])
516+
#output_scope_set = mod.add_metadata([input_scope, output_scope])
509517

510518
inputs = [
511519
extract_array(aryty, ary)
@@ -551,8 +559,8 @@ def extract_array(aryty, ary):
551559
context, builder, *array_info, idxs_bc, *safe
552560
)
553561
val = builder.load(ptr)
554-
val.set_metadata("alias.scope", input_scope_set)
555-
val.set_metadata("noalias", output_scope_set)
562+
#val.set_metadata("alias.scope", input_scope_set)
563+
#val.set_metadata("noalias", output_scope_set)
556564
input_vals.append(val)
557565

558566
# Call scalar function
@@ -572,8 +580,14 @@ def extract_array(aryty, ary):
572580
zip(output_accumulator, output_values, strict=True)
573581
):
574582
if accu is not None:
575-
new_value = builder.fadd(builder.load(accu), value)
576-
builder.store(new_value, accu)
583+
load = builder.load(accu)
584+
#load.set_metadata("alias.scope", output_scope_set)
585+
#load.set_metadata("noalias", input_scope_set)
586+
new_value = builder.fadd(load, value)
587+
store = builder.store(new_value, accu)
588+
# TODO ?
589+
#store.set_metadata("alias.scope", output_scope_set)
590+
#store.set_metadata("noalias", input_scope_set)
577591
else:
578592
idxs_bc = [
579593
zero if bc else idx
@@ -582,9 +596,10 @@ def extract_array(aryty, ary):
582596
ptr = cgutils.get_item_pointer2(
583597
context, builder, *outputs[i], idxs_bc
584598
)
585-
store = builder.store(value, ptr)
586-
store.set_metadata("alias.scope", output_scope_set)
587-
store.set_metadata("noalias", input_scope_set)
599+
#store = builder.store(value, ptr)
600+
store = arrayobj.store_item(context, builder, out_types[i], value, ptr)
601+
#store.set_metadata("alias.scope", output_scope_set)
602+
#store.set_metadata("noalias", input_scope_set)
588603

589604
# Close the loops and write accumulator values to the output arrays
590605
for depth, loop in enumerate(loop_stack[::-1]):
@@ -599,16 +614,20 @@ def extract_array(aryty, ary):
599614
ptr = cgutils.get_item_pointer2(
600615
context, builder, *outputs[output], idxs_bc
601616
)
602-
store = builder.store(builder.load(accu), ptr)
603-
store.set_metadata("alias.scope", output_scope_set)
604-
store.set_metadata("noalias", input_scope_set)
617+
load = builder.load(accu)
618+
#load.set_metadata("alias.scope", output_scope_set)
619+
#load.set_metadata("noalias", input_scope_set)
620+
#store = builder.store(load, ptr)
621+
store = arrayobj.store_item(context, builder, out_types[output], load, ptr)
622+
#store.set_metadata("alias.scope", output_scope_set)
623+
#store.set_metadata("noalias", input_scope_set)
605624
loop.__exit__(None, None, None)
606625
return
607626

608627
return sig, codegen
609628

610629
def vectorized(*inputs):
611-
pass
630+
raise NotImplementedError()
612631

613632
@numba.extending.overload(vectorized, jit_options={"fastmath": flags})
614633
def impl_vectorized(*inputs):
@@ -635,17 +654,32 @@ def impl_vectorized(*inputs):
635654

636655
iter_shape_repeated = tuple([iter_shape_template[:] for _ in range(n_outputs)])
637656

638-
@numba.extending.register_jitable
639-
def make_output(iter_shape, bc, dtype):
640-
shape = iter_shape
641-
for i in range(ndim):
642-
if bc[i]:
643-
shape = tuple_setitem(
644-
shape,
645-
i,
646-
1,
647-
)
648-
return np.empty(shape, dtype)
657+
ndim_range = tuple(range(ndim))
658+
659+
if ndim > 0:
660+
# TODO workaround for https://github.com/numba/numba/issues/8654
661+
@numba.extending.register_jitable
662+
def make_output(iter_shape, bc, dtype):
663+
shape = iter_shape
664+
for i in literal_unroll(ndim_range):
665+
if bc[i]:
666+
shape = tuple_setitem(
667+
shape,
668+
i,
669+
1,
670+
)
671+
return np.empty(shape, dtype)
672+
673+
check_arrays = check_broadcasting
674+
else:
675+
@numba.extending.register_jitable
676+
def make_output(iter_shape, bc, dtype):
677+
return np.empty((), dtype)
678+
679+
@numba.extending.register_jitable
680+
def check_arrays(a, b, c):
681+
pass
682+
649683

650684
make_outputs = tuple_mapper(make_output)
651685

@@ -667,8 +701,6 @@ def impl(*inputs):
667701
)
668702

669703
outputs = make_outputs(iter_shape_rep, output_bc_patterns, output_dtypes)
670-
#outputs = (np.empty(inputs[0].shape),)
671-
#iter_shape = inputs[0].shape
672704

673705
i = 0
674706
for input_ in literal_unroll(inputs):
@@ -704,21 +736,24 @@ def numba_funcify_Elemwise(op, node, **kwargs):
704736
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
705737
scalar_node = op.scalar_op.make_node(*scalar_inputs)
706738

739+
flags = True
740+
{
741+
"arcp", # Allow Reciprocal
742+
"contract", # Allow floating-point contraction
743+
"afn", # Approximate functions
744+
"reassoc",
745+
}
746+
707747
scalar_op_fn = numba_funcify(
708-
op.scalar_op, node=scalar_node, parent_node=node, **kwargs
748+
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
709749
)
710750

711-
assert not op.inplace_pattern
712-
713-
#scalar_wrapper = register_jitable(scalar_op_fn)
714-
scalar_wrapper = scalar_op_fn
715-
716751
ndim = node.outputs[0].ndim
717752
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
718753
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
719754

720755
vectorized = _vectorize_bc(
721-
scalar_wrapper,
756+
scalar_op_fn,
722757
input_bc_patterns,
723758
output_bc_patterns,
724759
output_dtypes=tuple([
@@ -727,10 +762,22 @@ def numba_funcify_Elemwise(op, node, **kwargs):
727762
]),
728763
)
729764

765+
# TODO We should do this in vectorize instead
766+
if op.inplace_pattern:
767+
pattern = list(op.inplace_pattern.items())
768+
769+
@numba_njit
770+
def elemwise_inplace(*inputs):
771+
outputs = vectorized(*inputs)
772+
for out_idx, in_idx in literal_unroll(pattern):
773+
inputs[in_idx][...] = outputs[out_idx]
774+
else:
775+
elemwise_inplace = vectorized
776+
730777
if len(node.outputs) == 1:
731778
@numba_njit
732779
def elemwise_wrapper(*inputs):
733-
return vectorized(*inputs)[0]
780+
return elemwise_inplace(*inputs)[0]
734781
else:
735782
elemwise_wrapper = vectorized
736783

0 commit comments

Comments
 (0)