Skip to content

Commit 961620c

Browse files
authored
Add test that matches TTB for MATLAB output of HOSVD (#79)
This closes #78
1 parent 171aeb0 commit 961620c

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/test_hosvd.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ def sample_tensor():
1313
return params, tensorInstance
1414

1515

16+
@pytest.fixture()
17+
def sample_tensor_3way():
18+
shape = (3, 3, 3)
19+
data = np.array(range(1, 28)).reshape(shape, order="F")
20+
params = {"data": data, "shape": shape}
21+
tensorInstance = ttb.tensor().from_data(data, shape)
22+
return params, tensorInstance
23+
24+
1625
@pytest.mark.indevelopment
1726
def test_hosvd_simple_convergence(capsys, sample_tensor):
1827
(data, T) = sample_tensor
@@ -65,3 +74,48 @@ def test_hosvd_incorrect_dimorder(capsys, sample_tensor):
6574
dimorder = 1
6675
with pytest.raises(ValueError):
6776
_ = ttb.hosvd(T, 1, dimorder=dimorder)
77+
78+
79+
@pytest.mark.indevelopment
80+
def test_hosvd_3way(capsys, sample_tensor_3way):
81+
(data, T) = sample_tensor_3way
82+
M = ttb.hosvd(T, 1e-4, verbosity=0)
83+
capsys.readouterr()
84+
print(f"M=\n{M}")
85+
core = np.array(
86+
[
87+
[
88+
[-8.301598119750199e01, -5.005881796972034e-03],
89+
[-1.268039597172832e-02, 5.842630378620833e00],
90+
],
91+
[
92+
[3.709974006281391e-02, -1.915213813096568e00],
93+
[-5.157111619887230e-01, 5.243776123493664e-01],
94+
],
95+
]
96+
)
97+
fm0 = np.array(
98+
[
99+
[-5.452132631706279e-01, -7.321719955012304e-01],
100+
[-5.767748638548937e-01, -2.576993904719336e-02],
101+
[-6.083364645391598e-01, 6.806321174064961e-01],
102+
]
103+
)
104+
fm1 = np.array(
105+
[
106+
[-4.756392343758577e-01, 7.791666394653051e-01],
107+
[-5.719678320081717e-01, 7.865197061237804e-02],
108+
[-6.682964296404851e-01, -6.218626982406427e-01],
109+
]
110+
)
111+
fm2 = np.array(
112+
[
113+
[-1.922305666539489e-01, 8.924016710972924e-01],
114+
[-5.140779746206554e-01, 2.627873081852611e-01],
115+
[-8.359253825873615e-01, -3.668270547267537e-01],
116+
]
117+
)
118+
assert np.allclose(M.core.data, core)
119+
assert np.allclose(M.u[0], fm0)
120+
assert np.allclose(M.u[1], fm1)
121+
assert np.allclose(M.u[2], fm2)

0 commit comments

Comments
 (0)