Skip to content

Commit d9fa235

Browse files
authored
Merge pull request #292 from bluss/zip-rows
Reimplement rows case in `zip_mut_with` using `Zip`
2 parents f998c70 + 916920e commit d9fa235

File tree

4 files changed

+136
-61
lines changed

4 files changed

+136
-61
lines changed

src/impl_methods.rs

Lines changed: 12 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ use super::ZipExt;
2323
use dimension::IntoDimension;
2424
use dimension::{axes_of, Axes, merge_axes, stride_offset};
2525
use iterators::{
26-
new_inner_iter_smaller,
27-
new_inner_iter_smaller_mut,
26+
new_inners,
27+
new_inners_mut,
2828
whole_chunks_of,
2929
whole_chunks_mut_of,
3030
};
31+
use zip::Zip;
3132

3233
use {
3334
NdIndex,
@@ -1184,33 +1185,18 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
11841185
}
11851186
return;
11861187
}
1188+
11871189
// otherwise, break the arrays up into their inner rows
1188-
let mut try_slices = true;
1189-
let rows = new_inner_iter_smaller_mut(self.view_mut()).zip(
1190-
new_inner_iter_smaller(rhs.view()));
1191-
for (mut s_row, r_row) in rows {
1192-
if try_slices {
1193-
if let Some(self_s) = s_row.as_slice_mut() {
1194-
if let Some(rhs_s) = r_row.as_slice() {
1195-
let len = cmp::min(self_s.len(), rhs_s.len());
1196-
let s = &mut self_s[..len];
1197-
let r = &rhs_s[..len];
1198-
for i in 0..len {
1199-
f(&mut s[i], &r[i]);
1200-
}
1201-
continue;
1202-
}
1203-
}
1204-
try_slices = false;
1205-
}
1206-
unsafe {
1207-
for i in 0..s_row.len() {
1208-
f(s_row.uget_mut(i), r_row.uget(i))
1209-
}
1210-
}
1211-
}
1190+
let n = self.ndim();
1191+
let dim = self.raw_dim();
1192+
Zip::from(new_inners_mut(self.view_mut(), Axis(n - 1)))
1193+
.and(new_inners(rhs.broadcast_assume(dim), Axis(n - 1)))
1194+
.apply(move |s_row, r_row| {
1195+
Zip::from(s_row).and(r_row).apply(|a, b| f(a, b))
1196+
});
12121197
}
12131198

1199+
12141200
fn zip_mut_with_elem<B, F>(&mut self, rhs_elem: &B, mut f: F)
12151201
where S: DataMut,
12161202
F: FnMut(&mut A, &B)
@@ -1432,39 +1418,4 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14321418
}
14331419
})
14341420
}
1435-
1436-
#[cfg(lanes_along)]
1437-
fn lanes_along<'a, F>(&'a self, axis: Axis, mut visit: F)
1438-
where D: RemoveAxis,
1439-
F: FnMut(ArrayView1<'a, A>),
1440-
A: 'a,
1441-
{
1442-
let view_len = self.shape().axis(axis);
1443-
let view_stride = self.strides.axis(axis);
1444-
// use the 0th subview as a map to each 1d array view extended from
1445-
// the 0th element.
1446-
self.subview(axis, 0).visit(move |first_elt| {
1447-
unsafe {
1448-
visit(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
1449-
}
1450-
})
1451-
}
1452-
1453-
#[cfg(lanes_along)]
1454-
fn lanes_along_mut<'a, F>(&'a mut self, axis: Axis, mut visit: F)
1455-
where D: RemoveAxis,
1456-
S: DataMut,
1457-
F: FnMut(ArrayViewMut1<'a, A>),
1458-
A: 'a,
1459-
{
1460-
let view_len = self.shape().axis(axis);
1461-
let view_stride = self.strides.axis(axis);
1462-
// use the 0th subview as a map to each 1d array view extended from
1463-
// the 0th element.
1464-
self.subview_mut(axis, 0).unordered_foreach_mut(move |first_elt| {
1465-
unsafe {
1466-
visit(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
1467-
}
1468-
})
1469-
}
14701421
}

src/iterators/inners.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
2+
use imp_prelude::*;
3+
use {NdProducer, Layout};
4+
5+
impl_ndproducer! {
6+
['a, A, D: Dimension]
7+
[Clone => 'a, A, D: Clone ]
8+
Inners {
9+
base,
10+
inner_len,
11+
inner_stride,
12+
}
13+
Inners<'a, A, D> {
14+
type Dim = D;
15+
type Item = ArrayView<'a, A, Ix1>;
16+
17+
unsafe fn item(&self, ptr) {
18+
ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix))
19+
}
20+
}
21+
}
22+
23+
pub struct Inners<'a, A: 'a, D> {
24+
base: ArrayView<'a, A, D>,
25+
inner_len: Ix,
26+
inner_stride: Ixs,
27+
}
28+
29+
30+
pub fn new_inners<A, D>(v: ArrayView<A, D>, axis: Axis)
31+
-> Inners<A, D::TrySmaller>
32+
where D: Dimension
33+
{
34+
let ndim = v.ndim();
35+
let len;
36+
let stride;
37+
let iter_v;
38+
if ndim == 0 {
39+
len = 1;
40+
stride = 0;
41+
iter_v = v.try_remove_axis(Axis(0))
42+
} else {
43+
len = v.dim.last_elem();
44+
stride = v.strides.last_elem() as isize;
45+
iter_v = v.try_remove_axis(axis)
46+
}
47+
Inners {
48+
inner_len: len,
49+
inner_stride: stride,
50+
base: iter_v,
51+
}
52+
}
53+
54+
impl_ndproducer! {
55+
['a, A, D: Dimension]
56+
[Clone =>]
57+
InnersMut {
58+
base,
59+
inner_len,
60+
inner_stride,
61+
}
62+
InnersMut<'a, A, D> {
63+
type Dim = D;
64+
type Item = ArrayViewMut<'a, A, Ix1>;
65+
66+
unsafe fn item(&self, ptr) {
67+
ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix))
68+
}
69+
}
70+
}
71+
72+
pub struct InnersMut<'a, A: 'a, D> {
73+
base: ArrayViewMut<'a, A, D>,
74+
inner_len: Ix,
75+
inner_stride: Ixs,
76+
}
77+
78+
79+
pub fn new_inners_mut<A, D>(v: ArrayViewMut<A, D>, axis: Axis)
80+
-> InnersMut<A, D::TrySmaller>
81+
where D: Dimension
82+
{
83+
let ndim = v.ndim();
84+
let len;
85+
let stride;
86+
let iter_v;
87+
if ndim == 0 {
88+
len = 1;
89+
stride = 0;
90+
iter_v = v.try_remove_axis(Axis(0))
91+
} else {
92+
len = v.dim.last_elem();
93+
stride = v.strides.last_elem() as isize;
94+
iter_v = v.try_remove_axis(axis)
95+
}
96+
InnersMut {
97+
inner_len: len,
98+
inner_stride: stride,
99+
base: iter_v,
100+
}
101+
}

src/iterators/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#[macro_use] mod macros;
1111
mod chunks;
12+
mod inners;
1213

1314
use std::marker::PhantomData;
1415
use std::ptr;
@@ -35,6 +36,10 @@ pub use self::chunks::{
3536
WholeChunksIterMut,
3637
whole_chunks_mut_of,
3738
};
39+
pub use self::inners::{
40+
new_inners,
41+
new_inners_mut,
42+
};
3843

3944
/// Base for array iterators
4045
///
@@ -527,6 +532,7 @@ impl<'a, A, D> ExactSizeIterator for InnerIterMut<'a, A, D>
527532
}
528533
}
529534

535+
#[cfg(next_version)]
530536
/// Create an InnerIter one dimension smaller than D (if possible)
531537
pub fn new_inner_iter_smaller<A, D>(v: ArrayView<A, D>)
532538
-> InnerIter<A, D::TrySmaller>
@@ -552,6 +558,7 @@ pub fn new_inner_iter_smaller<A, D>(v: ArrayView<A, D>)
552558
}
553559
}
554560

561+
#[cfg(next_version)]
555562
pub fn new_inner_iter_smaller_mut<A, D>(v: ArrayViewMut<A, D>)
556563
-> InnerIterMut<A, D::TrySmaller>
557564
where D: Dimension,

src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,22 @@ impl<A, S, D> ArrayBase<S, D>
590590
}
591591
}
592592

593+
// Broadcast to dimension `E`, without checking that the dimensions match
594+
// (Checked in debug assertions).
595+
#[inline]
596+
fn broadcast_assume<E>(&self, dim: E) -> ArrayView<A, E>
597+
where E: Dimension,
598+
{
599+
let dim = dim.into_dimension();
600+
debug_assert_eq!(self.shape(), dim.slice());
601+
let ptr = self.ptr;
602+
let mut strides = dim.clone();
603+
strides.slice_mut().copy_from_slice(self.strides.slice());
604+
unsafe {
605+
ArrayView::new_(ptr, dim, strides)
606+
}
607+
}
608+
593609
fn raw_strides(&self) -> D {
594610
self.strides.clone()
595611
}

0 commit comments

Comments
 (0)