Skip to content

Commit dedb31f

Browse files
committed
Fix numpy DeprecationWarning when converting integers to PyTensor Constants
1 parent 0b632bd commit dedb31f

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

pytensor/misc/safe_asarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _asarray(a, dtype, order=None):
3232
if str(dtype) == "floatX":
3333
dtype = config.floatX
3434
dtype = np.dtype(dtype) # Convert into dtype object.
35-
rval = np.asarray(a, dtype=dtype, order=order)
35+
rval = np.asarray(a, order=order).astype(dtype)
3636
# Note that dtype comparison must be done by comparing their `num`
3737
# attribute. One cannot assume that two identical data types are pointers
3838
# towards the same object (e.g. under Windows this appears not to be the

tests/tensor/test_basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import warnings
23
from functools import partial
34
from tempfile import mkstemp
45

@@ -3171,6 +3172,7 @@ def ok(z, floatX):
31713172
ok(np.float64(x), floatX)
31723173

31733174

3175+
@pytest.mark.filterwarnings("error")
31743176
class TestLongTensor:
31753177
def test_fit_int64(self):
31763178
bitwidth = PYTHON_INT_BITWIDTH

0 commit comments

Comments
 (0)