Skip to content

Commit cebb4a6

Browse files
committed
Cache reshape and transpose
This stores the last few objects returned from reshape and transpose calls. This allows efficiencies from in-place operations like `sum_duplicates` and `sort_indices` to persist in interative workflows. Modern NumPy programmers are accustomed to operations like .transpose() being cheap and aren't accustomed to having to pay sorting costs after many computations. These assumptions are no longer true in sparse by default. However, by caching recent transpose and reshape objects we can reuse their inplace modifications. This greatly accelerates common machine learning workloads.
1 parent 29fec85 commit cebb4a6

File tree

2 files changed

+34
-20
lines changed

2 files changed

+34
-20
lines changed

sparse/core.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(self, coords, data=None, shape=None, has_duplicates=True,
154154
assert not self.shape or len(data) == self.coords.shape[1]
155155
self.has_duplicates = has_duplicates
156156
self.sorted = sorted
157+
self._cache = defaultdict(lambda: deque(maxlen=3))
157158

158159
@classmethod
159160
def from_numpy(cls, x):
@@ -209,6 +210,9 @@ def __getitem__(self, index):
209210
index = (index,)
210211
index = tuple(ind + self.shape[i] if isinstance(ind, numbers.Integral) and ind < 0 else ind
211212
for i, ind in enumerate(index))
213+
if (all(ind == slice(None) or ind == slice(0, d)
214+
for ind, d in zip(index, self.shape))):
215+
return self
212216
mask = np.ones(self.nnz, dtype=bool)
213217
for i, ind in enumerate([i for i in index if i is not None]):
214218
if ind == slice(None, None):
@@ -325,19 +329,15 @@ def transpose(self, axes=None):
325329
if axes == tuple(range(self.ndim)):
326330
return self
327331

332+
for ax, value in self._cache['transpose']:
333+
if ax == axes:
334+
return value
335+
328336
shape = tuple(self.shape[ax] for ax in axes)
329337
result = COO(self.coords[axes, :], self.data, shape,
330338
has_duplicates=self.has_duplicates)
331339

332-
if axes == (1, 0):
333-
try:
334-
result._csc = self._csr.T
335-
except AttributeError:
336-
pass
337-
try:
338-
result._csr = self._csc.T
339-
except AttributeError:
340-
pass
340+
self._cache['transpose'].append((axes, result))
341341
return result
342342

343343
@property
@@ -374,8 +374,14 @@ def reshape(self, shape):
374374
extra = int(np.prod(self.shape) /
375375
np.prod([d for d in shape if d != -1]))
376376
shape = tuple([d if d != -1 else extra for d in shape])
377+
377378
if self.shape == shape:
378379
return self
380+
381+
for sh, value in self._cache['reshape']:
382+
if sh == shape:
383+
return value
384+
379385
# TODO: this np.prod(self.shape) enforces a 2**64 limit to array size
380386
linear_loc = self.linear_loc()
381387

@@ -385,9 +391,12 @@ def reshape(self, shape):
385391
coords[-(i + 1), :] = (linear_loc // strides) % d
386392
strides *= d
387393

388-
return COO(coords, self.data, shape,
389-
has_duplicates=self.has_duplicates,
390-
sorted=self.sorted)
394+
result = COO(coords, self.data, shape,
395+
has_duplicates=self.has_duplicates,
396+
sorted=self.sorted)
397+
398+
self._cache['reshape'].append((shape, result))
399+
return result
391400

392401
def to_scipy_sparse(self):
393402
assert self.ndim == 2

sparse/tests/test_core.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,6 @@ def test_cache_csr():
418418
assert s.tocsr() is s.tocsr()
419419
assert s.tocsc() is s.tocsc()
420420

421-
st = s.T
422-
423-
assert_eq(st._csr, st)
424-
assert_eq(st._csc, st)
425-
426-
assert isinstance(st.tocsr(), scipy.sparse.csr_matrix)
427-
assert isinstance(st.tocsc(), scipy.sparse.csc_matrix)
428-
429421

430422
def test_empty_shape():
431423
x = COO([], [1.0])
@@ -469,3 +461,16 @@ def test_add_many_sparse_arrays():
469461
x = COO({(1, 1): 1})
470462
y = sum([x] * 100)
471463
assert y.nnz < np.prod(y.shape)
464+
465+
466+
def test_caching():
467+
x = COO({(10, 10, 10): 1})
468+
469+
assert x[:].reshape((100, 10)).transpose().tocsr() is x[:].reshape((100, 10)).transpose().tocsr()
470+
471+
x = COO({(1, 1, 1, 1, 1, 1, 1, 2): 1})
472+
473+
for i in range(x.ndim):
474+
x.reshape((1,) * i + (2,) + (1,) * (x.ndim - i - 1))
475+
476+
assert len(x._cache['reshape']) < 5

0 commit comments

Comments
 (0)