Skip to content

Commit 982e4c4

Browse files
authored
replaces numpy sqrt method with pytensor equivalent (#6405)
1 parent f231d13 commit 982e4c4

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pymc/distributions/transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ def log_jac_det(self, value, *rv_inputs):
333333
def extend_axis(array, axis):
334334
n = array.shape[axis] + 1
335335
sum_vals = array.sum(axis, keepdims=True)
336-
norm = sum_vals / (np.sqrt(n) + n)
337-
fill_val = norm - sum_vals / np.sqrt(n)
336+
norm = sum_vals / (at.sqrt(n) + n)
337+
fill_val = norm - sum_vals / at.sqrt(n)
338338

339339
out = at.concatenate([array, fill_val], axis=axis)
340340
return out - norm
@@ -346,8 +346,8 @@ def extend_axis_rev(array, axis):
346346
n = array.shape[normalized_axis]
347347
last = at.take(array, [-1], axis=normalized_axis)
348348

349-
sum_vals = -last * np.sqrt(n)
350-
norm = sum_vals / (np.sqrt(n) + n)
349+
sum_vals = -last * at.sqrt(n)
350+
norm = sum_vals / (at.sqrt(n) + n)
351351
slice_before = (slice(None, None),) * normalized_axis
352352

353353
return array[slice_before + (slice(None, -1),)] + norm

0 commit comments

Comments
 (0)