Skip to content

Commit eade612

Browse files
authored
TENSOR: Fix slices ref shen return value isn't scalar or vector. #41 (#50)
Closes #41
1 parent bc83b26 commit eade612

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

pyttb/tensor.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,7 +1271,6 @@ def __getitem__(self, item):
12711271
kpdims = [] # dimensions to keep
12721272
rmdims = [] # dimensions to remove
12731273

1274-
# Determine the new size and what dimensions to keep
12751274
# Determine the new size and what dimensions to keep
12761275
for i in range(0, len(region)):
12771276
if isinstance(region[i], slice):
@@ -1289,19 +1288,11 @@ def __getitem__(self, item):
12891288

12901289
# If the size is zero, then the result is returned as a scalar
12911290
# otherwise, we convert the result to a tensor
1292-
12931291
if newsiz.size == 0:
12941292
a = newdata
12951293
else:
1296-
if rmdims.size == 0:
1297-
a = ttb.tensor.from_data(newdata)
1298-
else:
1299-
# If extracted data is a vector then no need to tranpose it
1300-
if len(newdata.shape) == 1:
1301-
a = ttb.tensor.from_data(newdata)
1302-
else:
1303-
a = ttb.tensor.from_data(np.transpose(newdata, np.concatenate((kpdims, rmdims))))
1304-
return ttb.tt_subsubsref(a, item)
1294+
a = ttb.tensor.from_data(newdata)
1295+
return a
13051296

13061297
# *** CASE 2a: Subscript indexing ***
13071298
if len(item) > 1 and isinstance(item[-1], str) and item[-1] == 'extract':

tests/test_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def test_tensor__getitem__(sample_tensor_2way):
280280
assert tensorInstance[0, 0] == params['data'][0, 0]
281281
# Case 1 Subtensor
282282
assert (tensorInstance[:, :] == tensorInstance).data.all()
283+
three_way_data = np.random.random((2, 3, 4))
284+
two_slices = (slice(None,None,None), 0, slice(None,None,None))
285+
assert (ttb.tensor.from_data(three_way_data)[two_slices].double() == three_way_data[two_slices]).all()
283286
# Case 1 Subtensor
284287
assert (tensorInstance[np.array([0, 1]), :].data == tensorInstance.data[[0, 1], :]).all()
285288
# Case 1 Subtensor

0 commit comments

Comments
 (0)