Argmax/argmin for Array1 #416
Comments
It'd be useful to have a general argmin/argmax for
What's a universal max, min, sum? Just curious. On the topic of sum, the plan is that current Array method
The sum/stdev/variance/min/max/etc. operations reduce the dimensionality of the array by one for each axis being iterated over (since they remove that axis). For example, NumPy behaves like this (where the ...):

    >>> import numpy as np
    >>> x = np.zeros((2, 3, 4, 5))
    >>> x.sum(axis=None).shape
    ()
    >>> x.sum(axis=2).shape
    (2, 3, 5)
    >>> x.sum(axis=(1, 2)).shape
    (2, 5)

I'd suggest using a trait to handle the various types of axes arguments:

    pub trait FoldAxes<A, D: Dimension> {
        type Output;
        type Repr: AsRef<[isize]>;
        /// Should return `None` if fold is over all axes.
        /// Instead of `Option<Self::Repr>` here, it could just be `Self::Repr`, where
        /// a repr of length zero would mean "all axes".
        fn into_repr(self) -> Option<Self::Repr>;
    }

    // Scalar output case (all axes).
    impl<A, D: Dimension> FoldAxes<A, D> for () {
        type Output = A;
        type Repr = &'static [isize];
        fn into_repr(self) -> Option<&'static [isize]> {
            None
        }
    }

    // Iterate over single axis.
    impl<A, D: Dimension> FoldAxes<A, D> for isize {
        type Output = Array<A, D::Smaller>;
        type Repr = [isize; 1];
        fn into_repr(self) -> Option<[isize; 1]> {
            Some([self])
        }
    }

    // impl<A, D: Dimension> FoldAxes<A, D> for (isize,) {...}

    impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize) {
        type Output = Array<A, <<D as Dimension>::Smaller as Dimension>::Smaller>;
        type Repr = [isize; 2];
        fn into_repr(self) -> Option<[isize; 2]> {
            Some([self.0, self.1])
        }
    }

    // impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize) {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize) {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize, isize) {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize, isize, isize) {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for [isize; 1] {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for [isize; 2] {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for [isize; 3] {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for [isize; 4] {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for [isize; 5] {...}
    // impl<A, D: Dimension> FoldAxes<A, D> for [isize; 6] {...}

    impl<'a, A, D: Dimension> FoldAxes<A, D> for &'a [isize] {
        type Output = Array<A, IxDyn>;
        type Repr = &'a [isize];
        fn into_repr(self) -> Option<&'a [isize]> {
            Some(self)
        }
    }

    impl<A, D: Dimension> FoldAxes<A, D> for Vec<isize> {
        type Output = Array<A, IxDyn>;
        type Repr = Vec<isize>;
        fn into_repr(self) -> Option<Vec<isize>> {
            Some(self)
        }
    }

    // same for `usize`, tuples of `usize`, slices of `usize`, vecs of `usize` (casting `usize` to `isize`)

Then, for example,

    impl<A, S, D> ArrayBase<S, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        fn sum<T>(&self, axes: T) -> <T as FoldAxes<A, D>>::Output
        where
            A: Clone + Add<A, Output = A>,
            T: FoldAxes<A, D>,
        {
            // (or check zero-length here if not using the Option approach)
            let axes_repr: Option<T::Repr> = axes.into_repr();
            // `axes` was moved by `into_repr`, so match on `axes_repr` instead.
            if let Some(axes) = axes_repr {
                let axes_slice: &[isize] = axes.as_ref();
                // operate over specified axes...
            } else {
                // operate over all axes...
            }
        }
    }

which you could call like

    // Sum over all axes.
    arr.sum(());
    // Sum over axis 1.
    arr.sum(1);
    // Sum over axes 2 and 3.
    arr.sum((2, 3));
    // Sum over variable number of axes.
    arr.sum(axes_vec);

Note that

    impl<A, S, D> ArrayBase<S, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        fn argmin<T>(&self, axes: T) -> <T as FoldAxes<usize, D>>::Output
        where
            for<'a> &'a A: PartialOrd<&'a A>,
            T: FoldAxes<usize, D>,
        {
            // ...
        }
    }

You could reduce duplication of code in the method implementations by implementing them in terms of a generalized

Edit: For what it's worth, if all you want to support is

You could also have

Edit 2: I remembered that computing the standard deviation requires multiple numerical accumulators. So, instead of

    impl<A, S, D> ArrayBase<S, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        fn fold_map_axes<'a, T, B, F, M, O>(
            &'a self,
            axes: T,
            init: B,
            fold: F,
            map: M,
        ) -> Array<O, <T as FoldAxes<D>>::OutDim>
        where
            A: 'a,
            T: FoldAxes<D>,
            F: FnMut(B, &'a A) -> B,
            M: FnMut(B) -> O,
        {
            // ...
        }
    }
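
To make the fold-then-map idea from Edit 2 concrete, here is a minimal sketch in plain Rust, with no ndarray dependency. It assumes a 2-D array stored as a row-major slice and a single reduction axis. The names `fold_map_axis` and `std_axis` are hypothetical and only loosely mirror the `fold_map_axes` signature proposed above. The standard deviation uses three accumulators (count, sum, sum of squares) and the population formula, which is simple but not the most numerically stable choice.

```rust
/// Hypothetical helper, not ndarray API: folds every lane along `axis` of a
/// row-major `rows x cols` buffer, then maps each final accumulator to an
/// output value. The result has one entry per remaining index.
pub fn fold_map_axis<A, B, O, F, M>(
    data: &[A],
    rows: usize,
    cols: usize,
    axis: usize,
    init: B,
    mut fold: F,
    mut map: M,
) -> Vec<O>
where
    B: Clone,
    F: FnMut(B, &A) -> B,
    M: FnMut(B) -> O,
{
    assert_eq!(data.len(), rows * cols, "shape does not match buffer length");
    assert!(axis < 2, "axis must be 0 or 1");
    // Lengths of the axis being removed and of the axis that remains.
    let (reduced, kept) = if axis == 0 { (rows, cols) } else { (cols, rows) };
    // Element strides for stepping along the kept and the reduced axis.
    let (kept_stride, reduced_stride) = if axis == 0 { (1, cols) } else { (cols, 1) };
    (0..kept)
        .map(|k| {
            let mut acc = init.clone();
            for r in 0..reduced {
                acc = fold(acc, &data[k * kept_stride + r * reduced_stride]);
            }
            map(acc)
        })
        .collect()
}

/// Population standard deviation along `axis`, built on `fold_map_axis`.
/// An empty reduced axis yields NaN because of the division by zero.
pub fn std_axis(data: &[f64], rows: usize, cols: usize, axis: usize) -> Vec<f64> {
    fold_map_axis(
        data,
        rows,
        cols,
        axis,
        (0usize, 0.0f64, 0.0f64),
        // Fold step: update count, sum, and sum of squares together.
        |(n, sum, sum_sq), &x| (n + 1, sum + x, sum_sq + x * x),
        // Map step: turn the three accumulators into one output value.
        |(n, sum, sum_sq)| {
            let n = n as f64;
            let mean = sum / n;
            (sum_sq / n - mean * mean).max(0.0).sqrt()
        },
    )
}
```

A plain sum falls out of the same helper by passing `0` as `init`, `|acc, &x| acc + x` as the fold, and the identity closure as the map, which is the kind of code sharing the comment above is after.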
Any updates? This is frequently used... Thanks!
this is available in ndarray-stats
@Tootallgavin This does not seem to be able to get max elements along any axis.
You mean an
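
For reference, the "max elements along an axis" request can be sketched in plain Rust as follows. This is only an illustration under stated assumptions: the data is a 2-D array stored as a row-major slice, and `argmax_axis` is a hypothetical name, not a function from ndarray or ndarray-stats. As with the NumPy shapes discussed earlier, the result has one dimension fewer than the input.

```rust
/// Hypothetical helper, not ndarray API: for a row-major `rows x cols` buffer,
/// returns the position of the first maximum along `axis`
/// (0 = down each column, 1 = across each row).
pub fn argmax_axis<T: PartialOrd>(
    data: &[T],
    rows: usize,
    cols: usize,
    axis: usize,
) -> Vec<usize> {
    assert_eq!(data.len(), rows * cols, "shape does not match buffer length");
    assert!(axis < 2, "axis must be 0 or 1");
    // Lengths of the axis being removed and of the axis that remains.
    let (reduced, kept) = if axis == 0 { (rows, cols) } else { (cols, rows) };
    assert!(reduced > 0, "cannot take argmax along an empty axis");
    // Element strides for stepping along the kept and the reduced axis.
    let (kept_stride, reduced_stride) = if axis == 0 { (1, cols) } else { (cols, 1) };
    (0..kept)
        .map(|k| {
            let at = |r: usize| &data[k * kept_stride + r * reduced_stride];
            let mut best = 0;
            for r in 1..reduced {
                // Strict comparison keeps the earliest index on ties.
                if at(r) > at(best) {
                    best = r;
                }
            }
            best
        })
        .collect()
}
```

For the matrix `[[1, 9, 3], [7, 2, 8]]`, reducing along axis 0 gives one index per column and reducing along axis 1 gives one index per row.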
It would be useful to have a function like
https://athemathmo.github.io/rulinalg/doc/rulinalg/utils/fn.argmax.html
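
A slice-based function of that kind might look like the sketch below. It is not rulinalg's implementation and not an ndarray API; the names and the `Option` return type are assumptions chosen so that an empty slice is handled without panicking. It returns both the index and the value of the first extremum.

```rust
/// Hypothetical helper: index and value of the first maximum in `values`,
/// or `None` for an empty slice. Elements that do not compare as greater
/// (for example NaN) never replace the current best, so a leading NaN
/// is returned unchanged.
pub fn argmax<T: Copy + PartialOrd>(values: &[T]) -> Option<(usize, T)> {
    let mut iter = values.iter().copied().enumerate();
    let mut best = iter.next()?;
    for (i, v) in iter {
        if v > best.1 {
            best = (i, v);
        }
    }
    Some(best)
}

/// Hypothetical helper: index and value of the first minimum in `values`,
/// or `None` for an empty slice.
pub fn argmin<T: Copy + PartialOrd>(values: &[T]) -> Option<(usize, T)> {
    let mut iter = values.iter().copied().enumerate();
    let mut best = iter.next()?;
    for (i, v) in iter {
        if v < best.1 {
            best = (i, v);
        }
    }
    Some(best)
}
```

With ndarray, a 1-D array that is contiguous in memory can usually be viewed as a slice first, so a helper like this would cover the `Array1` case from the title.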