Skip to content

Commit edda2f0

Browse files
committed
Strongly typed Axis argument (newtype called Axis)
1 parent 0aab138 commit edda2f0

File tree

7 files changed

+125
-89
lines changed

7 files changed

+125
-89
lines changed

examples/axis.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
extern crate ndarray;
2+
3+
use ndarray::{
4+
OwnedArray,
5+
Axis,
6+
};
7+
8+
fn main() {
9+
let a = OwnedArray::<f32, _>::linspace(0., 24., 25).into_shape((5, 5)).unwrap();
10+
println!("{:?}", a.subview(Axis(0), 0));
11+
println!("{:?}", a.subview(Axis(0), 1));
12+
println!("{:?}", a.subview(Axis(1), 1));
13+
println!("{:?}", a.subview(Axis(0), 1));
14+
println!("{:?}", a.subview(Axis(2), 1)); // PANIC
15+
}

src/dimension.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,13 @@ mod test {
715715
assert!(super::dim_stride_overlap(&dim, &strides));
716716
}
717717
}
718+
719+
/// An axis index.
720+
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
721+
pub struct Axis(pub usize);
722+
723+
impl Axis {
724+
#[inline(always)]
725+
pub fn axis(&self) -> usize { self.0 }
726+
}
727+

src/lib.rs

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ use itertools::free::enumerate;
9292
pub use dimension::{
9393
Dimension,
9494
RemoveAxis,
95+
Axis,
9596
};
9697

9798
pub use dimension::NdIndex;
@@ -292,7 +293,7 @@ pub type Ixs = isize;
292293
/// Subview takes two arguments: `axis` and `index`.
293294
///
294295
/// ```
295-
/// use ndarray::{arr3, aview2};
296+
/// use ndarray::{arr3, aview2, Axis};
296297
///
297298
/// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`.
298299
///
@@ -308,8 +309,8 @@ pub type Ixs = isize;
308309
/// // Let’s take a subview along the greatest dimension (axis 0),
309310
/// // taking submatrix 0, then submatrix 1
310311
///
311-
/// let sub_0 = a.subview(0, 0);
312-
/// let sub_1 = a.subview(0, 1);
312+
/// let sub_0 = a.subview(Axis(0), 0);
313+
/// let sub_1 = a.subview(Axis(0), 1);
313314
///
314315
/// assert_eq!(sub_0, aview2(&[[ 1, 2, 3],
315316
/// [ 4, 5, 6]]));
@@ -318,7 +319,7 @@ pub type Ixs = isize;
318319
/// assert_eq!(sub_0.shape(), &[2, 3]);
319320
///
320321
/// // This is the subview picking only axis 2, column 0
321-
/// let sub_col = a.subview(2, 0);
322+
/// let sub_col = a.subview(Axis(2), 0);
322323
///
323324
/// assert_eq!(sub_col, aview2(&[[ 1, 4],
324325
/// [ 7, 10]]));
@@ -1265,7 +1266,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
12651266
/// **Panics** if `axis` or `index` is out of bounds.
12661267
///
12671268
/// ```
1268-
/// use ndarray::{arr1, arr2};
1269+
/// use ndarray::{arr1, arr2, Axis};
12691270
///
12701271
/// let a = arr2(&[[1., 2.], // -- axis 0, row 0
12711272
/// [3., 4.], // -- axis 0, row 1
@@ -1274,13 +1275,13 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
12741275
/// // \ axis 1, column 1
12751276
/// // axis 1, column 0
12761277
/// assert!(
1277-
/// a.subview(0, 1) == arr1(&[3., 4.]) &&
1278-
/// a.subview(1, 1) == arr1(&[2., 4., 6.])
1278+
/// a.subview(Axis(0), 1) == arr1(&[3., 4.]) &&
1279+
/// a.subview(Axis(1), 1) == arr1(&[2., 4., 6.])
12791280
/// );
12801281
/// ```
1281-
pub fn subview(&self, axis: usize, index: Ix)
1282+
pub fn subview(&self, axis: Axis, index: Ix)
12821283
-> ArrayView<A, <D as RemoveAxis>::Smaller>
1283-
where D: RemoveAxis
1284+
where D: RemoveAxis,
12841285
{
12851286
self.view().into_subview(axis, index)
12861287
}
@@ -1291,19 +1292,19 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
12911292
/// **Panics** if `axis` or `index` is out of bounds.
12921293
///
12931294
/// ```
1294-
/// use ndarray::{arr2, aview2};
1295+
/// use ndarray::{arr2, aview2, Axis};
12951296
///
12961297
/// let mut a = arr2(&[[1., 2.],
12971298
/// [3., 4.]]);
12981299
///
1299-
/// a.subview_mut(1, 1).iadd_scalar(&10.);
1300+
/// a.subview_mut(Axis(1), 1).iadd_scalar(&10.);
13001301
///
13011302
/// assert!(
13021303
/// a == aview2(&[[1., 12.],
13031304
/// [3., 14.]])
13041305
/// );
13051306
/// ```
1306-
pub fn subview_mut(&mut self, axis: usize, index: Ix)
1307+
pub fn subview_mut(&mut self, axis: Axis, index: Ix)
13071308
-> ArrayViewMut<A, D::Smaller>
13081309
where S: DataMut,
13091310
D: RemoveAxis,
@@ -1315,19 +1316,21 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
13151316
/// and select the subview of `index` along that axis.
13161317
///
13171318
/// **Panics** if `index` is past the length of the axis.
1318-
pub fn isubview(&mut self, axis: usize, index: Ix) {
1319-
dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides, axis, index)
1319+
pub fn isubview(&mut self, axis: Axis, index: Ix) {
1320+
dimension::do_sub(&mut self.dim, &mut self.ptr, &self.strides,
1321+
axis.axis(), index)
13201322
}
13211323

13221324
/// Along `axis`, select the subview `index` and return `self`
13231325
/// with that axis removed.
13241326
///
13251327
/// See [`.subview()`](#method.subview) and [*Subviews*](#subviews) for full documentation.
1326-
pub fn into_subview(mut self, axis: usize, index: Ix)
1328+
pub fn into_subview(mut self, axis: Axis, index: Ix)
13271329
-> ArrayBase<S, <D as RemoveAxis>::Smaller>
1328-
where D: RemoveAxis
1330+
where D: RemoveAxis,
13291331
{
13301332
self.isubview(axis, index);
1333+
let axis = axis.axis();
13311334
// don't use reshape -- we always know it will fit the size,
13321335
// and we can use remove_axis on the strides as well
13331336
ArrayBase {
@@ -1379,15 +1382,16 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
13791382
/// Iterator element is `ArrayView<A, D::Smaller>` (read-only array view).
13801383
///
13811384
/// ```
1382-
/// use ndarray::arr3;
1385+
/// use ndarray::{arr3, Axis};
1386+
///
13831387
/// let a = arr3(&[[[ 0, 1, 2], // \ axis 0, submatrix 0
13841388
/// [ 3, 4, 5]], // /
13851389
/// [[ 6, 7, 8], // \ axis 0, submatrix 1
13861390
/// [ 9, 10, 11]]]); // /
13871391
/// // `outer_iter` yields the two submatrices along axis 0.
13881392
/// let mut iter = a.outer_iter();
1389-
/// assert_eq!(iter.next().unwrap(), a.subview(0, 0));
1390-
/// assert_eq!(iter.next().unwrap(), a.subview(0, 1));
1393+
/// assert_eq!(iter.next().unwrap(), a.subview(Axis(0), 0));
1394+
/// assert_eq!(iter.next().unwrap(), a.subview(Axis(0), 1));
13911395
/// ```
13921396
pub fn outer_iter(&self) -> OuterIter<A, D::Smaller>
13931397
where D: RemoveAxis,
@@ -1418,10 +1422,10 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14181422
/// See [*Subviews*](#subviews) for full documentation.
14191423
///
14201424
/// **Panics** if `axis` is out of bounds.
1421-
pub fn axis_iter(&self, axis: usize) -> OuterIter<A, D::Smaller>
1422-
where D: RemoveAxis
1425+
pub fn axis_iter(&self, axis: Axis) -> OuterIter<A, D::Smaller>
1426+
where D: RemoveAxis,
14231427
{
1424-
iterators::new_axis_iter(self.view(), axis)
1428+
iterators::new_axis_iter(self.view(), axis.axis())
14251429
}
14261430

14271431

@@ -1432,11 +1436,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14321436
/// (read-write array view).
14331437
///
14341438
/// **Panics** if `axis` is out of bounds.
1435-
pub fn axis_iter_mut(&mut self, axis: usize) -> OuterIterMut<A, D::Smaller>
1439+
pub fn axis_iter_mut(&mut self, axis: Axis) -> OuterIterMut<A, D::Smaller>
14361440
where S: DataMut,
14371441
D: RemoveAxis,
14381442
{
1439-
iterators::new_axis_iter_mut(self.view_mut(), axis)
1443+
iterators::new_axis_iter_mut(self.view_mut(), axis.axis())
14401444
}
14411445

14421446
/// Return an iterator that traverses over `axis` by chunks of `size`,
@@ -1451,20 +1455,22 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14511455
///
14521456
/// ```
14531457
/// use ndarray::OwnedArray;
1454-
/// use ndarray::arr3;
1458+
/// use ndarray::{arr3, Axis};
14551459
///
14561460
/// let a = OwnedArray::from_iter(0..28).into_shape((2, 7, 2)).unwrap();
1457-
/// let mut iter = a.axis_chunks_iter(1, 2);
1461+
/// let mut iter = a.axis_chunks_iter(Axis(1), 2);
14581462
///
14591463
/// // first iteration yields a 2 × 2 × 2 view
14601464
/// assert_eq!(iter.next().unwrap(),
1461-
/// arr3(&[[[0, 1], [2, 3]], [[14, 15], [16, 17]]]));
1465+
/// arr3(&[[[ 0, 1], [ 2, 3]],
1466+
/// [[14, 15], [16, 17]]]));
14621467
///
14631468
/// // however the last element is a 2 × 1 × 2 view since 7 % 2 == 1
1464-
/// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]], [[26, 27]]]));
1469+
/// assert_eq!(iter.next_back().unwrap(), arr3(&[[[12, 13]],
1470+
/// [[26, 27]]]));
14651471
/// ```
1466-
pub fn axis_chunks_iter(&self, axis: usize, size: usize) -> AxisChunksIter<A, D> {
1467-
iterators::new_chunk_iter(self.view(), axis, size)
1472+
pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<A, D> {
1473+
iterators::new_chunk_iter(self.view(), axis.axis(), size)
14681474
}
14691475

14701476
/// Return an iterator that traverses over `axis` by chunks of `size`,
@@ -1473,11 +1479,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
14731479
/// Iterator element is `ArrayViewMut<A, D>`
14741480
///
14751481
/// **Panics** if `axis` is out of bounds.
1476-
pub fn axis_chunks_iter_mut(&mut self, axis: usize, size: usize)
1482+
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize)
14771483
-> AxisChunksIterMut<A, D>
14781484
where S: DataMut
14791485
{
1480-
iterators::new_chunk_iter_mut(self.view_mut(), axis, size)
1486+
iterators::new_chunk_iter_mut(self.view_mut(), axis.axis(), size)
14811487
}
14821488

14831489
// Return (length, stride) for diagonal
@@ -2229,24 +2235,24 @@ impl<A, S, D> ArrayBase<S, D>
22292235
/// Return sum along `axis`.
22302236
///
22312237
/// ```
2232-
/// use ndarray::{aview0, aview1, arr2};
2238+
/// use ndarray::{aview0, aview1, arr2, Axis};
22332239
///
22342240
/// let a = arr2(&[[1., 2.],
22352241
/// [3., 4.]]);
22362242
/// assert!(
2237-
/// a.sum(0) == aview1(&[4., 6.]) &&
2238-
/// a.sum(1) == aview1(&[3., 7.]) &&
2243+
/// a.sum(Axis(0)) == aview1(&[4., 6.]) &&
2244+
/// a.sum(Axis(1)) == aview1(&[3., 7.]) &&
22392245
///
2240-
/// a.sum(0).sum(0) == aview0(&10.)
2246+
/// a.sum(Axis(0)).sum(Axis(0)) == aview0(&10.)
22412247
/// );
22422248
/// ```
22432249
///
22442250
/// **Panics** if `axis` is out of bounds.
2245-
pub fn sum(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
2251+
pub fn sum(&self, axis: Axis) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
22462252
where A: Clone + Add<Output=A>,
22472253
D: RemoveAxis,
22482254
{
2249-
let n = self.shape()[axis];
2255+
let n = self.shape()[axis.axis()];
22502256
let mut res = self.subview(axis, 0).to_owned();
22512257
for i in 1..n {
22522258
let view = self.subview(axis, i);
@@ -2283,24 +2289,23 @@ impl<A, S, D> ArrayBase<S, D>
22832289

22842290
/// Return mean along `axis`.
22852291
///
2292+
/// **Panics** if `axis` is out of bounds.
2293+
///
22862294
/// ```
2287-
/// use ndarray::{aview1, arr2};
2295+
/// use ndarray::{aview1, arr2, Axis};
22882296
///
22892297
/// let a = arr2(&[[1., 2.],
22902298
/// [3., 4.]]);
22912299
/// assert!(
2292-
/// a.mean(0) == aview1(&[2.0, 3.0]) &&
2293-
/// a.mean(1) == aview1(&[1.5, 3.5])
2300+
/// a.mean(Axis(0)) == aview1(&[2.0, 3.0]) &&
2301+
/// a.mean(Axis(1)) == aview1(&[1.5, 3.5])
22942302
/// );
22952303
/// ```
2296-
///
2297-
///
2298-
/// **Panics** if `axis` is out of bounds.
2299-
pub fn mean(&self, axis: usize) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
2304+
pub fn mean(&self, axis: Axis) -> OwnedArray<A, <D as RemoveAxis>::Smaller>
23002305
where A: LinalgScalar,
23012306
D: RemoveAxis,
23022307
{
2303-
let n = self.shape()[axis];
2308+
let n = self.shape()[axis.axis()];
23042309
let mut sum = self.sum(axis);
23052310
let one = libnum::one::<A>();
23062311
let mut cnt = one;
@@ -2413,7 +2418,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24132418
/// **Panics** if `index` is out of bounds.
24142419
pub fn row(&self, index: Ix) -> ArrayView<A, Ix>
24152420
{
2416-
self.subview(0, index)
2421+
self.subview(Axis(0), index)
24172422
}
24182423

24192424
/// Return a mutable array view of row `index`.
@@ -2422,15 +2427,15 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24222427
pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut<A, Ix>
24232428
where S: DataMut
24242429
{
2425-
self.subview_mut(0, index)
2430+
self.subview_mut(Axis(0), index)
24262431
}
24272432

24282433
/// Return an array view of column `index`.
24292434
///
24302435
/// **Panics** if `index` is out of bounds.
24312436
pub fn column(&self, index: Ix) -> ArrayView<A, Ix>
24322437
{
2433-
self.subview(1, index)
2438+
self.subview(Axis(1), index)
24342439
}
24352440

24362441
/// Return a mutable array view of column `index`.
@@ -2439,7 +2444,7 @@ impl<A, S> ArrayBase<S, (Ix, Ix)>
24392444
pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut<A, Ix>
24402445
where S: DataMut
24412446
{
2442-
self.subview_mut(1, index)
2447+
self.subview_mut(Axis(1), index)
24432448
}
24442449

24452450
/// Perform matrix multiplication of rectangular arrays `self` and `rhs`.

0 commit comments

Comments
 (0)