Skip to content

Commit ec8e7e5

Browse files
authored
Merge pull request #40 from zm711/master
Change `np.in1d` into `np.isin` as `np.in1d` will be deprecated with new versions of NumPy
2 parents 2b567cd + 099fed6 commit ec8e7e5

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

phylib/io/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def _spikes_in_clusters(spike_clusters, clusters):
328328
"""Return the ids of all spikes belonging to the specified clusters."""
329329
if len(spike_clusters) == 0 or len(clusters) == 0:
330330
return np.array([], dtype=int)
331-
return np.nonzero(np.in1d(spike_clusters, clusters))[0]
331+
return np.nonzero(np.isin(spike_clusters, clusters))[0]
332332

333333

334334
def _spikes_per_cluster(spike_clusters, spike_ids=None):

phylib/io/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def from_sparse(data, cols, channel_ids):
9090
# NOTE: we ensure here that `col` contains integers.
9191
c = cols.flatten().astype(np.int32)
9292
# Remove columns that do not belong to the specified channels.
93-
c[~np.in1d(c, channel_ids)] = -1
94-
assert np.all(np.in1d(c, np.r_[channel_ids, -1]))
93+
c[~np.isin(c, channel_ids)] = -1
94+
assert np.all(np.isin(c, np.r_[channel_ids, -1]))
9595
# Convert column indices to relative indices given the specified
9696
# channel_ids.
9797
cols_loc = _index_of(c, np.r_[channel_ids, -1]).reshape(cols.shape)

phylib/io/tests/test_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def test_spikes_in_clusters():
288288
assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, [i])] == i)
289289

290290
clusters = [1, 2, 3]
291-
assert np.all(np.in1d(
291+
assert np.all(np.isin(
292292
spike_clusters[_spikes_in_clusters(spike_clusters, clusters)], clusters))
293293

294294

phylib/stats/tests/test_clusters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_sorted_main_channels(masks):
103103
mean_masks = mean(masks)
104104
channels = get_sorted_main_channels(mean_masks,
105105
get_unmasked_channels(mean_masks))
106-
assert np.all(np.in1d(channels, [5, 7]))
106+
assert np.all(np.isin(channels, [5, 7]))
107107

108108

109109
def test_waveform_amplitude(masks, waveforms):

0 commit comments

Comments
 (0)