Skip to content

Commit 2ab1934

Browse files
authored
Allowing rdims or cdims to be empty array. (#43)
Closes #42
1 parent e296f3a commit 2ab1934

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

pyttb/tenmat.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,19 @@ def from_tensor_type(cls, source, rdims=None, cdims=None, cdims_cyclic=None):
113113
elif rdims is None and cdims is not None:
114114
rdims = np.setdiff1d(alldims, cdims)
115115

116-
117-
dims = np.hstack([rdims, cdims])
116+
# if rdims or cdims is empty, hstack will output an array of float not int
117+
if rdims.size == 0:
118+
dims = cdims.copy()
119+
elif cdims.size == 0:
120+
dims = rdims.copy()
121+
else:
122+
dims = np.hstack([rdims, cdims])
118123
if not len(dims) == n or not (alldims == np.sort(dims)).all():
119124
assert False, 'Incorrect specification of dimensions, the sorted concatenation of rdims and cdims must be range(source.ndims).'
120125

121-
data = np.reshape(source.permute(dims).data, (np.prod(np.array(tshape)[rdims]), np.prod(np.array(tshape)[cdims])), order='F')
126+
rprod = 1 if rdims.size == 0 else np.prod(np.array(tshape)[rdims])
127+
cprod = 1 if cdims.size == 0 else np.prod(np.array(tshape)[cdims])
128+
data = np.reshape(source.permute(dims).data, (rprod, cprod), order='F')
122129

123130
# Create tenmat
124131
tenmatInstance = cls()

tests/test_tenmat.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def sample_ndarray_2way():
2222
params = {'data':ndarrayInstance, 'shape':shape}
2323
return params, ndarrayInstance
2424

25+
@pytest.fixture()
26+
def sample_tensor_3way():
27+
data = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
28+
shape = (2, 3, 2)
29+
params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape}
30+
tensorInstance = ttb.tensor().from_data(data, shape)
31+
return params, tensorInstance
32+
2533
@pytest.fixture()
2634
def sample_ndarray_4way():
2735
shape = (2, 2, 2, 2)
@@ -184,8 +192,9 @@ def test_tenmat_initialization_from_data(sample_ndarray_1way, sample_ndarray_2wa
184192
assert exc in str(excinfo)
185193

186194
@pytest.mark.indevelopment
187-
def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tensor_4way):
195+
def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tensor_3way, sample_tensor_4way):
188196
(_, tensorInstance) = sample_tensor_4way
197+
(_, tensorInstance3) = sample_tensor_3way
189198
(params, tenmatInstance) = sample_tenmat_4way
190199
tshape = params['tshape']
191200
rdims = params['rdims']
@@ -208,6 +217,11 @@ def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tenso
208217
assert tenmatInstance.shape == tenmatTensorRdims.shape
209218
assert tenmatInstance.tshape == tenmatTensorRdims.tshape
210219

220+
# Constructor from tensor using empty rdims
221+
tenmatTensorRdims = ttb.tenmat.from_tensor_type(tensorInstance3, rdims=np.array([]))
222+
data = np.reshape(np.arange(1,13),(1,12))
223+
assert (tenmatTensorRdims.data == data).all()
224+
211225
# Constructor from tensor using cdims only
212226
tenmatTensorCdims = ttb.tenmat.from_tensor_type(tensorInstance, cdims=cdims)
213227
assert (tenmatInstance.data == tenmatTensorCdims.data).all()
@@ -216,6 +230,11 @@ def test_tenmat_initialization_from_tensor_type(sample_tenmat_4way, sample_tenso
216230
assert tenmatInstance.shape == tenmatTensorCdims.shape
217231
assert tenmatInstance.tshape == tenmatTensorCdims.tshape
218232

233+
# Constructor from tensor using empty cdims
234+
tenmatTensorCdims = ttb.tenmat.from_tensor_type(tensorInstance3, cdims=np.array([]))
235+
data = np.reshape(np.arange(1,13),(12,1))
236+
assert (tenmatTensorCdims.data == data).all()
237+
219238
# Constructor from tensor using rdims and cdims
220239
tenmatTensorRdimsCdims = ttb.tenmat.from_tensor_type(tensorInstance, rdims=rdims, cdims=cdims)
221240
assert (tenmatInstance.data == tenmatTensorRdimsCdims.data).all()

0 commit comments

Comments
 (0)