Skip to content

Commit 6d26158

Browse files
authored
Tensor generator helpers (#93)
* TENONES: Add initial tenones support * TENZEROS: Add initial tenzeros support * TENDIAG: Add initial tendiag support * SPTENDIAG: Add initial sptendiag support
1 parent ab3b410 commit 6d26158

File tree

6 files changed

+198
-4
lines changed

6 files changed

+198
-4
lines changed

pyttb/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from pyttb.ktensor import ktensor
1616
from pyttb.pyttb_utils import *
1717
from pyttb.sptenmat import sptenmat
18-
from pyttb.sptensor import sptensor
18+
from pyttb.sptensor import sptendiag, sptensor
1919
from pyttb.sptensor3 import sptensor3
2020
from pyttb.sumtensor import sumtensor
2121
from pyttb.symktensor import symktensor
2222
from pyttb.symtensor import symtensor
2323
from pyttb.tenmat import tenmat
24-
from pyttb.tensor import tensor
24+
from pyttb.tensor import tendiag, tenones, tensor, tenzeros
2525
from pyttb.ttensor import ttensor
2626
from pyttb.tucker_als import tucker_als
2727

pyttb/pyttb_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def tt_subscheck(subs, nargout=True):
700700
len(subs.shape) == 2
701701
and (np.isfinite(subs)).all()
702702
and issubclass(subs.dtype.type, np.integer)
703-
and (subs > 0).all()
703+
and (subs >= 0).all()
704704
):
705705
ok = True
706706
else:

pyttb/sptensor.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def from_aggregator(
297297

298298
# Check for subscripts out of range
299299
for j, dim in enumerate(shape):
300-
if subs.size > 0 and np.max(subs[:, j]) > dim:
300+
if subs.size > 0 and np.max(subs[:, j]) >= dim:
301301
assert False, "Subscript exceeds sptensor shape"
302302

303303
if subs.size == 0:
@@ -2582,3 +2582,39 @@ def ttm(
25822582
# TODO evaluate performance loss by casting into sptensor then tensor.
25832583
# I assume minimal since we are already using spare matrix representation
25842584
return ttb.tensor.from_tensor_type(Ynt)
2585+
2586+
2587+
def sptendiag(
2588+
elements: np.ndarray, shape: Optional[Tuple[int, ...]] = None
2589+
) -> sptensor:
2590+
"""
2591+
Creates a sparse tensor with elements along super diagonal
2592+
If provided shape is too small the tensor will be enlarged to accomodate
2593+
2594+
Parameters
2595+
----------
2596+
elements: Elements to set along the diagonal
2597+
shape: Shape of resulting tensor
2598+
2599+
Returns
2600+
-------
2601+
Constructed tensor
2602+
2603+
Example
2604+
-------
2605+
>>> shape = (2,)
2606+
>>> values = np.ones(shape)
2607+
>>> X = ttb.sptendiag(values)
2608+
>>> Y = ttb.sptendiag(values, (2, 2))
2609+
>>> X.isequal(Y)
2610+
True
2611+
"""
2612+
# Flatten provided elements
2613+
elements = np.ravel(elements)
2614+
N = len(elements)
2615+
if shape is None:
2616+
constructed_shape = (N,) * N
2617+
else:
2618+
constructed_shape = tuple(max(N, dim) for dim in shape)
2619+
subs = np.tile(np.arange(0, N).transpose(), (len(constructed_shape), 1)).transpose()
2620+
return sptensor.from_aggregator(subs, elements.reshape((N, 1)), constructed_shape)

pyttb/tensor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,6 +1835,80 @@ def __repr__(self):
18351835
__str__ = __repr__
18361836

18371837

1838+
def tenones(shape: Tuple[int, ...]) -> tensor:
1839+
"""
1840+
Creates a tensor of all ones
1841+
1842+
Parameters
1843+
----------
1844+
shape: Shape of resulting tensor
1845+
1846+
Returns
1847+
-------
1848+
Constructed tensor
1849+
1850+
Example
1851+
-------
1852+
>>> X = ttb.tenones((2,2))
1853+
"""
1854+
return tensor.from_function(np.ones, shape)
1855+
1856+
1857+
def tenzeros(shape: Tuple[int, ...]) -> tensor:
1858+
"""
1859+
Creates a tensor of all zeros
1860+
1861+
Parameters
1862+
----------
1863+
shape: Shape of resulting tensor
1864+
1865+
Returns
1866+
-------
1867+
Constructed tensor
1868+
1869+
Example
1870+
-------
1871+
>>> X = ttb.tenzeros((2,2))
1872+
"""
1873+
return tensor.from_function(np.zeros, shape)
1874+
1875+
1876+
def tendiag(elements: np.ndarray, shape: Optional[Tuple[int, ...]] = None) -> tensor:
1877+
"""
1878+
Creates a tensor with elements along super diagonal
1879+
If provided shape is too small the tensor will be enlarged to accomodate
1880+
1881+
Parameters
1882+
----------
1883+
elements: Elements to set along the diagonal
1884+
shape: Shape of resulting tensor
1885+
1886+
Returns
1887+
-------
1888+
Constructed tensor
1889+
1890+
Example
1891+
-------
1892+
>>> shape = (2,)
1893+
>>> values = np.ones(shape)
1894+
>>> X = ttb.tendiag(values)
1895+
>>> Y = ttb.tendiag(values, (2, 2))
1896+
>>> X.isequal(Y)
1897+
True
1898+
"""
1899+
# Flatten provided elements
1900+
elements = np.ravel(elements)
1901+
N = len(elements)
1902+
if shape is None:
1903+
constructed_shape = (N,) * N
1904+
else:
1905+
constructed_shape = tuple(max(N, dim) for dim in shape)
1906+
X = tenzeros(constructed_shape)
1907+
subs = np.tile(np.arange(0, N).transpose(), (len(constructed_shape), 1))
1908+
X[subs] = elements
1909+
return X
1910+
1911+
18381912
if __name__ == "__main__":
18391913
import doctest # pragma: no cover
18401914

tests/test_sptensor.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,3 +1758,39 @@ def test_sptensor_from_sparse_matrix():
17581758
Xnt = tt_to_sparse_matrix(sptensorCopy, mode, False)
17591759
Ynt = tt_from_sparse_matrix(Xnt, sptensorCopy.shape, mode, 1)
17601760
assert sptensorCopy.isequal(Ynt)
1761+
1762+
1763+
def test_sptendiag():
1764+
N = 4
1765+
elements = np.arange(0, N)
1766+
exact_shape = [N] * N
1767+
1768+
# Inferred shape
1769+
X = ttb.sptendiag(elements)
1770+
for i in range(N):
1771+
diag_index = (i,) * N
1772+
assert (
1773+
X[diag_index] == i
1774+
), f"Idx: {diag_index} expected: {i} got: {X[diag_index]}"
1775+
1776+
# Exact shape
1777+
X = ttb.sptendiag(elements, tuple(exact_shape))
1778+
for i in range(N):
1779+
diag_index = (i,) * N
1780+
assert X[diag_index] == i
1781+
1782+
# Larger shape
1783+
larger_shape = exact_shape.copy()
1784+
larger_shape[0] += 1
1785+
X = ttb.sptendiag(elements, tuple(larger_shape))
1786+
for i in range(N):
1787+
diag_index = (i,) * N
1788+
assert X[diag_index] == i
1789+
1790+
# Smaller Shape
1791+
smaller_shape = exact_shape.copy()
1792+
smaller_shape[0] -= 1
1793+
X = ttb.sptendiag(elements, tuple(smaller_shape))
1794+
for i in range(N):
1795+
diag_index = (i,) * N
1796+
assert X[diag_index] == i

tests/test_tensor.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,3 +1632,51 @@ def test_tensor_nvecs(sample_tensor_2way):
16321632
"Greater than or equal to tensor.shape[n] - 1 eigenvectors requires cast to dense to solve"
16331633
in str(record[0].message)
16341634
)
1635+
1636+
1637+
def test_tenones():
1638+
arbitrary_shape = (3, 3, 3)
1639+
ones_tensor = ttb.tenones(arbitrary_shape)
1640+
data_tensor = ttb.tensor.from_data(np.ones(arbitrary_shape))
1641+
assert np.equal(ones_tensor, data_tensor), "Tenones should match all ones tensor"
1642+
1643+
1644+
def test_tenzeros():
1645+
arbitrary_shape = (3, 3, 3)
1646+
zeros_tensor = ttb.tenzeros(arbitrary_shape)
1647+
data_tensor = ttb.tensor.from_data(np.zeros(arbitrary_shape))
1648+
assert np.equal(zeros_tensor, data_tensor), "Tenzeros should match all zeros tensor"
1649+
1650+
1651+
def test_tendiag():
1652+
N = 4
1653+
elements = np.arange(0, N)
1654+
exact_shape = [N] * N
1655+
1656+
# Inferred shape
1657+
X = ttb.tendiag(elements)
1658+
for i in range(N):
1659+
diag_index = (i,) * N
1660+
assert X[diag_index] == i
1661+
1662+
# Exact shape
1663+
X = ttb.tendiag(elements, tuple(exact_shape))
1664+
for i in range(N):
1665+
diag_index = (i,) * N
1666+
assert X[diag_index] == i
1667+
1668+
# Larger shape
1669+
larger_shape = exact_shape.copy()
1670+
larger_shape[0] += 1
1671+
X = ttb.tendiag(elements, tuple(larger_shape))
1672+
for i in range(N):
1673+
diag_index = (i,) * N
1674+
assert X[diag_index] == i
1675+
1676+
# Smaller Shape
1677+
smaller_shape = exact_shape.copy()
1678+
smaller_shape[0] -= 1
1679+
X = ttb.tendiag(elements, tuple(smaller_shape))
1680+
for i in range(N):
1681+
diag_index = (i,) * N
1682+
assert X[diag_index] == i

0 commit comments

Comments
 (0)