1
1
# Copyright 2022 National Technology & Engineering Solutions of Sandia,
2
2
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
3
3
# U.S. Government retains certain rights in this software.
4
-
4
+ """PYTTB shared utilities across tensor types"""
5
5
from inspect import signature
6
6
from typing import Optional , Tuple , overload
7
7
8
8
import numpy as np
9
- import scipy .sparse as sparse
10
9
11
10
import pyttb as ttb
12
11
13
12
14
13
def tt_to_dense_matrix (tensorInstance , mode , transpose = False ):
15
14
"""
16
- Helper function to unwrap tensor into dense matrix, should replace the core need for tenmat
15
+ Helper function to unwrap tensor into dense matrix, should replace the core need
16
+ for tenmat
17
17
18
18
Parameters
19
19
----------
@@ -46,7 +46,8 @@ def tt_to_dense_matrix(tensorInstance, mode, transpose=False):
46
46
47
47
def tt_from_dense_matrix (matrix , shape , mode , idx ):
48
48
"""
49
- Helper function to wrap dense matrix into tensor. Inverse of :class:`pyttb.tt_to_dense_matrix`
49
+ Helper function to wrap dense matrix into tensor.
50
+ Inverse of :class:`pyttb.tt_to_dense_matrix`
50
51
51
52
Parameters
52
53
----------
@@ -70,63 +71,6 @@ def tt_from_dense_matrix(matrix, shape, mode, idx):
70
71
return tensorInstance
71
72
72
73
73
- def tt_to_sparse_matrix (sptensorInstance , mode , transpose = False ):
74
- """
75
- Helper function to unwrap sptensor into sparse matrix, should replace the core need for sptenmat
76
-
77
- Parameters
78
- ----------
79
- sptensorInstance: :class:`pyttb.sptensor`
80
- mode: int
81
- Mode around which to unwrap tensor
82
- transpose: bool
83
- Whether or not to tranpose unwrapped tensor
84
-
85
- Returns
86
- -------
87
- spmatrix: :class:`Scipy.sparse.coo_matrix`
88
- """
89
- old = np .setdiff1d (np .arange (sptensorInstance .ndims ), mode ).astype (int )
90
- spmatrix = sptensorInstance .reshape (
91
- (np .prod (np .array (sptensorInstance .shape )[old ]),), old
92
- ).spmatrix ()
93
- if transpose :
94
- return spmatrix .transpose ()
95
- else :
96
- return spmatrix
97
-
98
-
99
- def tt_from_sparse_matrix (spmatrix , shape , mode , idx ):
100
- """
101
- Helper function to wrap sparse matrix into sptensor. Inverse of :class:`pyttb.tt_to_sparse_matrix`
102
-
103
- Parameters
104
- ----------
105
- spmatrix: :class:`Scipy.sparse.coo_matrix`
106
- mode: int
107
- Mode around which tensor was unwrapped
108
- idx: int
109
- in {0,1}, idx of mode in spmatrix, s.b. 0 for tranpose=True
110
-
111
- Returns
112
- -------
113
- sptensorInstance: :class:`pyttb.sptensor`
114
- """
115
- siz = np .array (shape )
116
- old = np .setdiff1d (np .arange (len (shape )), mode ).astype (int )
117
- sptensorInstance = ttb .sptensor .from_tensor_type (sparse .coo_matrix (spmatrix ))
118
-
119
- # This expands the compressed dimension back to full size
120
- sptensorInstance = sptensorInstance .reshape (siz [old ], idx )
121
- # This puts the modes in the right order, reshape places modified modes after the unchanged ones
122
- sptensorInstance = sptensorInstance .reshape (
123
- shape ,
124
- np .concatenate ((np .arange (1 , mode + 1 ), [0 ], np .arange (mode + 1 , len (shape )))),
125
- )
126
-
127
- return sptensorInstance
128
-
129
-
130
74
def tt_union_rows (MatrixA , MatrixB ):
131
75
"""
132
76
Helper function to reproduce functionality of MATLABS intersect(a,b,'rows')
@@ -206,8 +150,8 @@ def tt_dimscheck(
206
150
# Save dimensions of dims
207
151
P = len (dims )
208
152
209
- # Reorder dims from smallest to largest
210
- # (this matters in particular for the vector multiplicand case, where the order affects the result)
153
+ # Reorder dims from smallest to largest (this matters in particular for the vector
154
+ # multiplicand case, where the order affects the result)
211
155
sidx = np .argsort (dims )
212
156
sdims = dims [sidx ]
213
157
vidx = None
@@ -217,24 +161,25 @@ def tt_dimscheck(
217
161
if M > N :
218
162
assert False , "Cannot have more multiplicands than dimensions"
219
163
220
- # Check that the number of multiplicands must either be full dimensional or equal to the specified dimensions
221
- # (M==N) or M(==P) respectively
222
- if M != N and M != P :
164
+ # Check that the number of multiplicands must either be full dimensional or
165
+ # equal to the specified dimensions (M==N) or M(==P) respectively
166
+ if M not in ( N , P ) :
223
167
assert False , "Invalid number of multiplicands"
224
168
225
169
# Check sizes to determine how to index multiplicands
226
170
if P == M :
227
- # Case 1: Number of items in dims and number of multiplicands are equal; therfore, index in order of sdims
171
+ # Case 1: Number of items in dims and number of multiplicands are equal;
172
+ # therfore, index in order of sdims
228
173
vidx = sidx
229
174
else :
230
- # Case 2: Number of multiplicands is equal to the number of dimensions of tensor;
231
- # therefore, index multiplicands by dimensions in dims argument.
175
+ # Case 2: Number of multiplicands is equal to the number of dimensions of
176
+ # tensor; therefore, index multiplicands by dimensions in dims argument.
232
177
vidx = sdims
233
178
234
179
return sdims , vidx
235
180
236
181
237
- def tt_tenfun (function_handle , * inputs ):
182
+ def tt_tenfun (function_handle , * inputs ): # pylint:disable=too-many-branches
238
183
"""
239
184
Apply a function to each element in a tensor
240
185
@@ -256,13 +201,13 @@ def tt_tenfun(function_handle, *inputs):
256
201
assert callable (function_handle ), "function_handle must be callable"
257
202
258
203
# Convert inputs to tensors if they aren't already
259
- for i in range ( 0 , len ( inputs ) ):
260
- if isinstance (inputs [ i ], ttb .tensor ) or isinstance ( inputs [ i ], ( float , int )):
204
+ for i , an_input in enumerate ( inputs ):
205
+ if isinstance (an_input , ( ttb .tensor , float , int )):
261
206
continue
262
- elif isinstance (inputs [ i ] , np .ndarray ):
263
- inputs [i ] = ttb .tensor .from_data (inputs [ i ] )
207
+ if isinstance (an_input , np .ndarray ):
208
+ inputs [i ] = ttb .tensor .from_data (an_input )
264
209
elif isinstance (
265
- inputs [ i ] ,
210
+ an_input ,
266
211
(
267
212
ttb .ktensor ,
268
213
ttb .ttensor ,
@@ -272,11 +217,12 @@ def tt_tenfun(function_handle, *inputs):
272
217
ttb .symktensor ,
273
218
),
274
219
):
275
- inputs [i ] = ttb .tensor .from_tensor_type (inputs [ i ] )
220
+ inputs [i ] = ttb .tensor .from_tensor_type (an_input )
276
221
else :
277
222
assert False , "Invalid input to ten fun"
278
223
279
- # It's ok if there are two input and one is a scalar; otherwise all inputs have to be the same size
224
+ # It's ok if there are two input and one is a scalar; otherwise all inputs have to
225
+ # be the same size
280
226
if (
281
227
(len (inputs ) == 2 )
282
228
and isinstance (inputs [0 ], (float , int ))
@@ -290,15 +236,15 @@ def tt_tenfun(function_handle, *inputs):
290
236
):
291
237
sz = inputs [0 ].shape
292
238
else :
293
- for i in range ( 0 , len ( inputs ) ):
294
- if isinstance (inputs [ i ] , (float , int )):
295
- assert False , "Argument {} is a scalar but expected a tensor" . format ( i )
239
+ for i , an_input in enumerate ( inputs ):
240
+ if isinstance (an_input , (float , int )):
241
+ assert False , f "Argument { i } is a scalar but expected a tensor"
296
242
elif i == 0 :
297
- sz = inputs [ i ] .shape
298
- elif sz != inputs [ i ] .shape :
243
+ sz = an_input .shape
244
+ elif sz != an_input .shape :
299
245
assert (
300
246
False
301
- ), "Tensor {} is not the same size as the first tensor input" . format ( i )
247
+ ), f "Tensor { i } is not the same size as the first tensor input"
302
248
303
249
# Number of inputs for function handle
304
250
nfunin = len (signature (function_handle ).parameters )
@@ -322,8 +268,8 @@ def tt_tenfun(function_handle, *inputs):
322
268
X = np .reshape (X , (1 , - 1 ))
323
269
else :
324
270
X = np .zeros ((len (inputs ), np .prod (sz )))
325
- for i in range ( 0 , len ( inputs ) ):
326
- X [i , :] = np .reshape (inputs [ i ] .data , (np .prod (sz )))
271
+ for i , an_input in enumerate ( inputs ):
272
+ X [i , :] = np .reshape (an_input .data , (np .prod (sz )))
327
273
data = function_handle (X )
328
274
data = np .reshape (data , sz )
329
275
Z = ttb .tensor .from_data (data )
@@ -395,7 +341,7 @@ def tt_intersect_rows(MatrixA, MatrixB):
395
341
return location [np .where (location >= 0 )]
396
342
397
343
398
- def tt_irenumber (t , shape , number_range ):
344
+ def tt_irenumber (t , shape , number_range ): # pylint: disable=unused-argument
399
345
"""
400
346
RENUMBER indices for sptensor subsasgn
401
347
@@ -409,25 +355,25 @@ def tt_irenumber(t, shape, number_range):
409
355
-------
410
356
newsubs: :class:`numpy.ndarray`
411
357
"""
412
- # TODO shape is unused. Should it be used? I don't particularly understand what this is meant to be doing
358
+ # TODO shape is unused. Should it be used? I don't particularly understand what
359
+ # this is meant to be doing
413
360
nz = t .nnz
414
361
if nz == 0 :
415
362
newsubs = np .array ([])
416
363
return newsubs
417
- else :
418
- newsubs = t .subs .astype (int )
419
- for i in range (0 , len (number_range )):
420
- r = number_range [i ]
421
- if isinstance (r , slice ):
422
- newsubs [:, i ] = (newsubs [:, i ])[r ]
423
- elif isinstance (r , int ):
424
- # This appears to be inserting new keys as rows to our subs here
425
- newsubs = np .insert (newsubs , obj = i , values = r , axis = 1 )
426
- else :
427
- if isinstance (r , list ):
428
- r = np .array (r )
429
- newsubs [:, i ] = r [newsubs [:, i ]]
430
- return newsubs
364
+
365
+ newsubs = t .subs .astype (int )
366
+ for i , r in enumerate (number_range ):
367
+ if isinstance (r , slice ):
368
+ newsubs [:, i ] = (newsubs [:, i ])[r ]
369
+ elif isinstance (r , int ):
370
+ # This appears to be inserting new keys as rows to our subs here
371
+ newsubs = np .insert (newsubs , obj = i , values = r , axis = 1 )
372
+ else :
373
+ if isinstance (r , list ):
374
+ r = np .array (r )
375
+ newsubs [:, i ] = r [newsubs [:, i ]]
376
+ return newsubs
431
377
432
378
433
379
def tt_assignment_type (x , subs , rhs ):
@@ -444,13 +390,12 @@ def tt_assignment_type(x, subs, rhs):
444
390
-------
445
391
objectType
446
392
"""
447
- if type (x ) == type (rhs ):
393
+ if type (x ) is type (rhs ):
448
394
return "subtensor"
449
395
# If subscripts is a tuple that contains an nparray
450
- elif isinstance (subs , tuple ) and len (subs ) >= 2 :
396
+ if isinstance (subs , tuple ) and len (subs ) >= 2 :
451
397
return "subtensor"
452
- else :
453
- return "subscripts"
398
+ return "subscripts"
454
399
455
400
456
401
def tt_renumber (subs , shape , number_range ):
@@ -476,8 +421,8 @@ def tt_renumber(subs, shape, number_range):
476
421
"""
477
422
newshape = np .array (shape )
478
423
newsubs = subs
479
- for i in range (0 , len (shape )):
480
- if not ( number_range [i ] == slice (None , None , None ) ):
424
+ for i in range (0 , len (shape )): # pylint: disable=consider-using-enumerate
425
+ if not number_range [i ] == slice (None , None , None ):
481
426
if subs .size == 0 :
482
427
if not isinstance (number_range [i ], slice ):
483
428
if isinstance (number_range [i ], (int , float )):
@@ -529,12 +474,14 @@ def tt_renumberdim(idx, shape, number_range):
529
474
return newidx , newshape
530
475
531
476
477
+ # TODO make more efficient, decide if we want to support the multiple response
478
+ # matlab does
479
+ # pylint: disable=line-too-long
480
+ # https://stackoverflow.com/questions/22699756/python-version-of-ismember-with-rows-and-index
481
+ # For thoughts on how to speed this up
532
482
def tt_ismember_rows (search , source ):
533
483
"""
534
484
Find location of search rows in source array
535
- https://stackoverflow.com/questions/22699756/python-version-of-ismember-with-rows-and-index
536
- For thoughts on how to speed this up
537
- #TODO make more efficient, decide if we want to support the multiple response matlab does
538
485
539
486
Parameters
540
487
----------
@@ -551,10 +498,10 @@ def tt_ismember_rows(search, source):
551
498
Examples
552
499
--------
553
500
>>> a = np.array([[4, 6], [1, 9], [2, 6]])
554
- >>> b = np.array([[1, 7],[1, 8],[ 2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
501
+ >>> b = np.array([[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
555
502
>>> results = tt_ismember_rows(a,b)
556
503
>>> print(results)
557
- [ 5 -1 2 ]
504
+ [ 3 -1 0 ]
558
505
559
506
"""
560
507
results = np .ones (shape = search .shape [0 ]) * - 1
@@ -585,7 +532,7 @@ def tt_ind2sub(shape: Tuple[int, ...], idx: np.ndarray) -> np.ndarray:
585
532
return np .array (np .unravel_index (idx , shape , order = "F" )).transpose ()
586
533
587
534
588
- def tt_subsubsref (obj , s ):
535
+ def tt_subsubsref (obj , s ): # pylint: disable=unused-argument
589
536
"""
590
537
Helper function for tensor toolbox subsref.
591
538
@@ -598,7 +545,8 @@ def tt_subsubsref(obj, s):
598
545
-------
599
546
Still uncertain to this functionality
600
547
"""
601
- # TODO figure out when subsref yields key of length>1 for now ignore this logic and just return
548
+ # TODO figure out when subsref yields key of length>1 for now ignore this logic and
549
+ # just return
602
550
# if len(s) == 1:
603
551
# return obj
604
552
# else:
@@ -608,7 +556,8 @@ def tt_subsubsref(obj, s):
608
556
609
557
def tt_intvec2str (v ):
610
558
"""
611
- Print integer vector to a string with brackets. Numpy should already handle this so it is a placeholder stub
559
+ Print integer vector to a string with brackets. Numpy should already handle this so
560
+ it is a placeholder stub
612
561
613
562
Parameters
614
563
----------
@@ -774,10 +723,7 @@ def isrow(v):
774
723
-------
775
724
bool
776
725
"""
777
- if v .ndim == 2 and v .shape [0 ] == 1 and v .shape [1 ] >= 1 :
778
- return True
779
- else :
780
- return False
726
+ return v .ndim == 2 and v .shape [0 ] == 1 and v .shape [1 ] >= 1
781
727
782
728
783
729
def isvector (a ):
@@ -794,13 +740,11 @@ def isvector(a):
794
740
-------
795
741
bool
796
742
"""
797
- if a .ndim == 1 or (a .ndim == 2 and (a .shape [0 ] == 1 or a .shape [1 ] == 1 )):
798
- return True
799
- else :
800
- return False
743
+ return a .ndim == 1 or (a .ndim == 2 and (a .shape [0 ] == 1 or a .shape [1 ] == 1 ))
801
744
802
745
803
- # TODO: this is a challenge, since it may need to apply to either Python built in types or numpy types
746
+ # TODO: this is a challenge, since it may need to apply to either Python built in types
747
+ # or numpy types
804
748
def islogical (a ):
805
749
"""
806
750
ISLOGICAL Checks if vector is a logical vector.
@@ -815,4 +759,4 @@ def islogical(a):
815
759
-------
816
760
bool
817
761
"""
818
- return type ( a ) == bool
762
+ return isinstance ( a , bool )
0 commit comments