Skip to content

Commit bdb0ba8

Browse files
authored
Lint pyttb_utils and lint/type sptensor (#77)
* PYTTB_UTILS: Fix and enforce pylint * PYTTB_UTILS: Pull out utility only used internally in sptensor * SPTENSOR: Fix and enforce pylint * SPTENSOR: Initial pass a typing support * SPTENSOR: Complete initial typing coverage * SPTENSOR: Fix test coverage from typing changes. * PYLINT: Update test to lint files in parallel to improve dev experience.
1 parent 8776944 commit bdb0ba8

File tree

6 files changed

+562
-493
lines changed

6 files changed

+562
-493
lines changed

pyttb/pyttb_utils.py

Lines changed: 69 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
# Copyright 2022 National Technology & Engineering Solutions of Sandia,
22
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
33
# U.S. Government retains certain rights in this software.
4-
4+
"""PYTTB shared utilities across tensor types"""
55
from inspect import signature
66
from typing import Optional, Tuple, overload
77

88
import numpy as np
9-
import scipy.sparse as sparse
109

1110
import pyttb as ttb
1211

1312

1413
def tt_to_dense_matrix(tensorInstance, mode, transpose=False):
1514
"""
16-
Helper function to unwrap tensor into dense matrix, should replace the core need for tenmat
15+
Helper function to unwrap tensor into dense matrix, should replace the core need
16+
for tenmat
1717
1818
Parameters
1919
----------
@@ -46,7 +46,8 @@ def tt_to_dense_matrix(tensorInstance, mode, transpose=False):
4646

4747
def tt_from_dense_matrix(matrix, shape, mode, idx):
4848
"""
49-
Helper function to wrap dense matrix into tensor. Inverse of :class:`pyttb.tt_to_dense_matrix`
49+
Helper function to wrap dense matrix into tensor.
50+
Inverse of :class:`pyttb.tt_to_dense_matrix`
5051
5152
Parameters
5253
----------
@@ -70,63 +71,6 @@ def tt_from_dense_matrix(matrix, shape, mode, idx):
7071
return tensorInstance
7172

7273

73-
def tt_to_sparse_matrix(sptensorInstance, mode, transpose=False):
74-
"""
75-
Helper function to unwrap sptensor into sparse matrix, should replace the core need for sptenmat
76-
77-
Parameters
78-
----------
79-
sptensorInstance: :class:`pyttb.sptensor`
80-
mode: int
81-
Mode around which to unwrap tensor
82-
transpose: bool
83-
Whether or not to transpose unwrapped tensor
84-
85-
Returns
86-
-------
87-
spmatrix: :class:`Scipy.sparse.coo_matrix`
88-
"""
89-
old = np.setdiff1d(np.arange(sptensorInstance.ndims), mode).astype(int)
90-
spmatrix = sptensorInstance.reshape(
91-
(np.prod(np.array(sptensorInstance.shape)[old]),), old
92-
).spmatrix()
93-
if transpose:
94-
return spmatrix.transpose()
95-
else:
96-
return spmatrix
97-
98-
99-
def tt_from_sparse_matrix(spmatrix, shape, mode, idx):
100-
"""
101-
Helper function to wrap sparse matrix into sptensor. Inverse of :class:`pyttb.tt_to_sparse_matrix`
102-
103-
Parameters
104-
----------
105-
spmatrix: :class:`Scipy.sparse.coo_matrix`
106-
mode: int
107-
Mode around which tensor was unwrapped
108-
idx: int
109-
in {0,1}, idx of mode in spmatrix, s.b. 0 for transpose=True
110-
111-
Returns
112-
-------
113-
sptensorInstance: :class:`pyttb.sptensor`
114-
"""
115-
siz = np.array(shape)
116-
old = np.setdiff1d(np.arange(len(shape)), mode).astype(int)
117-
sptensorInstance = ttb.sptensor.from_tensor_type(sparse.coo_matrix(spmatrix))
118-
119-
# This expands the compressed dimension back to full size
120-
sptensorInstance = sptensorInstance.reshape(siz[old], idx)
121-
# This puts the modes in the right order, reshape places modified modes after the unchanged ones
122-
sptensorInstance = sptensorInstance.reshape(
123-
shape,
124-
np.concatenate((np.arange(1, mode + 1), [0], np.arange(mode + 1, len(shape)))),
125-
)
126-
127-
return sptensorInstance
128-
129-
13074
def tt_union_rows(MatrixA, MatrixB):
13175
"""
13276
Helper function to reproduce functionality of MATLABS intersect(a,b,'rows')
@@ -206,8 +150,8 @@ def tt_dimscheck(
206150
# Save dimensions of dims
207151
P = len(dims)
208152

209-
# Reorder dims from smallest to largest
210-
# (this matters in particular for the vector multiplicand case, where the order affects the result)
153+
# Reorder dims from smallest to largest (this matters in particular for the vector
154+
# multiplicand case, where the order affects the result)
211155
sidx = np.argsort(dims)
212156
sdims = dims[sidx]
213157
vidx = None
@@ -217,24 +161,25 @@ def tt_dimscheck(
217161
if M > N:
218162
assert False, "Cannot have more multiplicands than dimensions"
219163

220-
# Check that the number of multiplicands must either be full dimensional or equal to the specified dimensions
221-
# (M==N) or M(==P) respectively
222-
if M != N and M != P:
164+
# Check that the number of multiplicands must either be full dimensional or
165+
# equal to the specified dimensions (M==N) or M(==P) respectively
166+
if M not in (N, P):
223167
assert False, "Invalid number of multiplicands"
224168

225169
# Check sizes to determine how to index multiplicands
226170
if P == M:
227-
# Case 1: Number of items in dims and number of multiplicands are equal; therfore, index in order of sdims
171+
# Case 1: Number of items in dims and number of multiplicands are equal;
172+
# therefore, index in order of sdims
228173
vidx = sidx
229174
else:
230-
# Case 2: Number of multiplicands is equal to the number of dimensions of tensor;
231-
# therefore, index multiplicands by dimensions in dims argument.
175+
# Case 2: Number of multiplicands is equal to the number of dimensions of
176+
# tensor; therefore, index multiplicands by dimensions in dims argument.
232177
vidx = sdims
233178

234179
return sdims, vidx
235180

236181

237-
def tt_tenfun(function_handle, *inputs):
182+
def tt_tenfun(function_handle, *inputs): # pylint:disable=too-many-branches
238183
"""
239184
Apply a function to each element in a tensor
240185
@@ -256,13 +201,13 @@ def tt_tenfun(function_handle, *inputs):
256201
assert callable(function_handle), "function_handle must be callable"
257202

258203
# Convert inputs to tensors if they aren't already
259-
for i in range(0, len(inputs)):
260-
if isinstance(inputs[i], ttb.tensor) or isinstance(inputs[i], (float, int)):
204+
for i, an_input in enumerate(inputs):
205+
if isinstance(an_input, (ttb.tensor, float, int)):
261206
continue
262-
elif isinstance(inputs[i], np.ndarray):
263-
inputs[i] = ttb.tensor.from_data(inputs[i])
207+
if isinstance(an_input, np.ndarray):
208+
inputs[i] = ttb.tensor.from_data(an_input)
264209
elif isinstance(
265-
inputs[i],
210+
an_input,
266211
(
267212
ttb.ktensor,
268213
ttb.ttensor,
@@ -272,11 +217,12 @@ def tt_tenfun(function_handle, *inputs):
272217
ttb.symktensor,
273218
),
274219
):
275-
inputs[i] = ttb.tensor.from_tensor_type(inputs[i])
220+
inputs[i] = ttb.tensor.from_tensor_type(an_input)
276221
else:
277222
assert False, "Invalid input to ten fun"
278223

279-
# It's ok if there are two input and one is a scalar; otherwise all inputs have to be the same size
224+
# It's ok if there are two input and one is a scalar; otherwise all inputs have to
225+
# be the same size
280226
if (
281227
(len(inputs) == 2)
282228
and isinstance(inputs[0], (float, int))
@@ -290,15 +236,15 @@ def tt_tenfun(function_handle, *inputs):
290236
):
291237
sz = inputs[0].shape
292238
else:
293-
for i in range(0, len(inputs)):
294-
if isinstance(inputs[i], (float, int)):
295-
assert False, "Argument {} is a scalar but expected a tensor".format(i)
239+
for i, an_input in enumerate(inputs):
240+
if isinstance(an_input, (float, int)):
241+
assert False, f"Argument {i} is a scalar but expected a tensor"
296242
elif i == 0:
297-
sz = inputs[i].shape
298-
elif sz != inputs[i].shape:
243+
sz = an_input.shape
244+
elif sz != an_input.shape:
299245
assert (
300246
False
301-
), "Tensor {} is not the same size as the first tensor input".format(i)
247+
), f"Tensor {i} is not the same size as the first tensor input"
302248

303249
# Number of inputs for function handle
304250
nfunin = len(signature(function_handle).parameters)
@@ -322,8 +268,8 @@ def tt_tenfun(function_handle, *inputs):
322268
X = np.reshape(X, (1, -1))
323269
else:
324270
X = np.zeros((len(inputs), np.prod(sz)))
325-
for i in range(0, len(inputs)):
326-
X[i, :] = np.reshape(inputs[i].data, (np.prod(sz)))
271+
for i, an_input in enumerate(inputs):
272+
X[i, :] = np.reshape(an_input.data, (np.prod(sz)))
327273
data = function_handle(X)
328274
data = np.reshape(data, sz)
329275
Z = ttb.tensor.from_data(data)
@@ -395,7 +341,7 @@ def tt_intersect_rows(MatrixA, MatrixB):
395341
return location[np.where(location >= 0)]
396342

397343

398-
def tt_irenumber(t, shape, number_range):
344+
def tt_irenumber(t, shape, number_range): # pylint: disable=unused-argument
399345
"""
400346
RENUMBER indices for sptensor subsasgn
401347
@@ -409,25 +355,25 @@ def tt_irenumber(t, shape, number_range):
409355
-------
410356
newsubs: :class:`numpy.ndarray`
411357
"""
412-
# TODO shape is unused. Should it be used? I don't particularly understand what this is meant to be doing
358+
# TODO shape is unused. Should it be used? I don't particularly understand what
359+
# this is meant to be doing
413360
nz = t.nnz
414361
if nz == 0:
415362
newsubs = np.array([])
416363
return newsubs
417-
else:
418-
newsubs = t.subs.astype(int)
419-
for i in range(0, len(number_range)):
420-
r = number_range[i]
421-
if isinstance(r, slice):
422-
newsubs[:, i] = (newsubs[:, i])[r]
423-
elif isinstance(r, int):
424-
# This appears to be inserting new keys as rows to our subs here
425-
newsubs = np.insert(newsubs, obj=i, values=r, axis=1)
426-
else:
427-
if isinstance(r, list):
428-
r = np.array(r)
429-
newsubs[:, i] = r[newsubs[:, i]]
430-
return newsubs
364+
365+
newsubs = t.subs.astype(int)
366+
for i, r in enumerate(number_range):
367+
if isinstance(r, slice):
368+
newsubs[:, i] = (newsubs[:, i])[r]
369+
elif isinstance(r, int):
370+
# This appears to be inserting new keys as rows to our subs here
371+
newsubs = np.insert(newsubs, obj=i, values=r, axis=1)
372+
else:
373+
if isinstance(r, list):
374+
r = np.array(r)
375+
newsubs[:, i] = r[newsubs[:, i]]
376+
return newsubs
431377

432378

433379
def tt_assignment_type(x, subs, rhs):
@@ -444,13 +390,12 @@ def tt_assignment_type(x, subs, rhs):
444390
-------
445391
objectType
446392
"""
447-
if type(x) == type(rhs):
393+
if type(x) is type(rhs):
448394
return "subtensor"
449395
# If subscripts is a tuple that contains an nparray
450-
elif isinstance(subs, tuple) and len(subs) >= 2:
396+
if isinstance(subs, tuple) and len(subs) >= 2:
451397
return "subtensor"
452-
else:
453-
return "subscripts"
398+
return "subscripts"
454399

455400

456401
def tt_renumber(subs, shape, number_range):
@@ -476,8 +421,8 @@ def tt_renumber(subs, shape, number_range):
476421
"""
477422
newshape = np.array(shape)
478423
newsubs = subs
479-
for i in range(0, len(shape)):
480-
if not (number_range[i] == slice(None, None, None)):
424+
for i in range(0, len(shape)): # pylint: disable=consider-using-enumerate
425+
if not number_range[i] == slice(None, None, None):
481426
if subs.size == 0:
482427
if not isinstance(number_range[i], slice):
483428
if isinstance(number_range[i], (int, float)):
@@ -529,12 +474,14 @@ def tt_renumberdim(idx, shape, number_range):
529474
return newidx, newshape
530475

531476

477+
# TODO make more efficient, decide if we want to support the multiple response
478+
# matlab does
479+
# pylint: disable=line-too-long
480+
# https://stackoverflow.com/questions/22699756/python-version-of-ismember-with-rows-and-index
481+
# For thoughts on how to speed this up
532482
def tt_ismember_rows(search, source):
533483
"""
534484
Find location of search rows in source array
535-
https://stackoverflow.com/questions/22699756/python-version-of-ismember-with-rows-and-index
536-
For thoughts on how to speed this up
537-
#TODO make more efficient, decide if we want to support the multiple response matlab does
538485
539486
Parameters
540487
----------
@@ -551,10 +498,10 @@ def tt_ismember_rows(search, source):
551498
Examples
552499
--------
553500
>>> a = np.array([[4, 6], [1, 9], [2, 6]])
554-
>>> b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
501+
>>> b = np.array([[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
555502
>>> results = tt_ismember_rows(a,b)
556503
>>> print(results)
557-
[ 5 -1 2]
504+
[ 3 -1 0]
558505
559506
"""
560507
results = np.ones(shape=search.shape[0]) * -1
@@ -585,7 +532,7 @@ def tt_ind2sub(shape: Tuple[int, ...], idx: np.ndarray) -> np.ndarray:
585532
return np.array(np.unravel_index(idx, shape, order="F")).transpose()
586533

587534

588-
def tt_subsubsref(obj, s):
535+
def tt_subsubsref(obj, s): # pylint: disable=unused-argument
589536
"""
590537
Helper function for tensor toolbox subsref.
591538
@@ -598,7 +545,8 @@ def tt_subsubsref(obj, s):
598545
-------
599546
Still uncertain to this functionality
600547
"""
601-
# TODO figure out when subsref yields key of length>1 for now ignore this logic and just return
548+
# TODO figure out when subsref yields key of length>1 for now ignore this logic and
549+
# just return
602550
# if len(s) == 1:
603551
# return obj
604552
# else:
@@ -608,7 +556,8 @@ def tt_subsubsref(obj, s):
608556

609557
def tt_intvec2str(v):
610558
"""
611-
Print integer vector to a string with brackets. Numpy should already handle this so it is a placeholder stub
559+
Print integer vector to a string with brackets. Numpy should already handle this so
560+
it is a placeholder stub
612561
613562
Parameters
614563
----------
@@ -774,10 +723,7 @@ def isrow(v):
774723
-------
775724
bool
776725
"""
777-
if v.ndim == 2 and v.shape[0] == 1 and v.shape[1] >= 1:
778-
return True
779-
else:
780-
return False
726+
return v.ndim == 2 and v.shape[0] == 1 and v.shape[1] >= 1
781727

782728

783729
def isvector(a):
@@ -794,13 +740,11 @@ def isvector(a):
794740
-------
795741
bool
796742
"""
797-
if a.ndim == 1 or (a.ndim == 2 and (a.shape[0] == 1 or a.shape[1] == 1)):
798-
return True
799-
else:
800-
return False
743+
return a.ndim == 1 or (a.ndim == 2 and (a.shape[0] == 1 or a.shape[1] == 1))
801744

802745

803-
# TODO: this is a challenge, since it may need to apply to either Python built in types or numpy types
746+
# TODO: this is a challenge, since it may need to apply to either Python built in types
747+
# or numpy types
804748
def islogical(a):
805749
"""
806750
ISLOGICAL Checks if vector is a logical vector.
@@ -815,4 +759,4 @@ def islogical(a):
815759
-------
816760
bool
817761
"""
818-
return type(a) == bool
762+
return isinstance(a, bool)

0 commit comments

Comments
 (0)