Argmax/argmin for Array1 #416

Open
robsmith11 opened this issue Feb 9, 2018 · 8 comments
Comments

@robsmith11

It would be useful to have a function like
https://athemathmo.github.io/rulinalg/doc/rulinalg/utils/fn.argmax.html
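
For illustration, here is a minimal std-only sketch of what such helpers might look like for a contiguous 1-D slice. The names `argmax` and `argmin` and the `Option<usize>` return type are assumptions made for this example, not the API of rulinalg or ndarray.

```rust
use std::cmp::Ordering;

/// Index of the largest element, or `None` if the slice is empty or
/// contains a value that cannot be ordered (for example NaN).
/// On ties, the first maximal index is returned.
pub fn argmax<T: PartialOrd>(values: &[T]) -> Option<usize> {
    if values.is_empty() {
        return None;
    }
    let mut best = 0;
    for (i, v) in values.iter().enumerate() {
        // `?` bails out with `None` as soon as a comparison is undefined.
        if v.partial_cmp(&values[best])? == Ordering::Greater {
            best = i;
        }
    }
    Some(best)
}

/// Index of the smallest element, with the same conventions as `argmax`.
pub fn argmin<T: PartialOrd>(values: &[T]) -> Option<usize> {
    if values.is_empty() {
        return None;
    }
    let mut best = 0;
    for (i, v) in values.iter().enumerate() {
        if v.partial_cmp(&values[best])? == Ordering::Less {
            best = i;
        }
    }
    Some(best)
}
```

Returning `Option` rather than panicking is one possible design choice; it makes the empty and NaN cases explicit at the call site.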

@ehsanmok
Contributor

ehsanmok commented Feb 19, 2018

It'd be useful to have a general argmin/argmax for ArrayBase. Also, it seems there are no universal min, max, or sum methods available for convenience!

@bluss
Member

bluss commented Feb 23, 2018

What's a universal max, min, sum? Just curious.

On the topic of sum, the plan is to rename the current Array method .scalar_sum() to .sum(), and to add methods like .std() and .std_axis() (for standard deviation) in the same style. Any innovative ideas about how to design the interface in a language where return types are statically known and there are no default arguments are very welcome, of course.

@ehsanmok
Contributor

@bluss By universal, I meant to distinguish between the sum of all elements and sum_axis, for example. Please see this as well.

@jturner314
Member

jturner314 commented Feb 23, 2018

The sum/stdev/variance/min/max/etc. operations reduce the dimensionality of the array by one for each axis being iterated over (since they remove that axis). For example, NumPy behaves like this (where the axis argument is optional):

>>> import numpy as np
>>> x = np.zeros((2, 3, 4, 5))
>>> x.sum(axis=None).shape
()
>>> x.sum(axis=2).shape
(2, 3, 5)
>>> x.sum(axis=(1, 2)).shape
(2, 5)
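
The shape rule shown in the NumPy session can be sketched in plain Rust. The function name `reduced_shape` and the use of `Option<&[usize]>` to encode "all axes" are assumptions for this example only.

```rust
/// Shape that remains after reducing over `axes`.
/// `None` means "all axes" (like NumPy's `axis=None`) and yields an empty shape.
/// Panics if an axis is out of bounds for `shape`.
pub fn reduced_shape(shape: &[usize], axes: Option<&[usize]>) -> Vec<usize> {
    match axes {
        None => Vec::new(),
        Some(axes) => {
            for &ax in axes {
                assert!(ax < shape.len(), "axis {} out of bounds", ax);
            }
            // Keep every axis length whose index is not being reduced.
            shape
                .iter()
                .enumerate()
                .filter(|(i, _)| !axes.contains(i))
                .map(|(_, &len)| len)
                .collect()
        }
    }
}
```

With the shape `(2, 3, 4, 5)` from the session above, reducing over axis 2 leaves `[2, 3, 5]` and reducing over axes 1 and 2 leaves `[2, 5]`.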

I'd suggest using a trait to handle the various types of axes arguments:

pub trait FoldAxes<A, D: Dimension> {
    type Output;
    type Repr: AsRef<[isize]>;
    /// Should return `None` if fold is over all axes.
    /// Instead of `Option<Self::Repr>` here, it could just be `Self::Repr`, where
    /// a repr of length zero would mean "all axes".
    fn into_repr(self) -> Option<Self::Repr>;
}

// Scalar output case (all axes).
impl<A, D: Dimension> FoldAxes<A, D> for () {
    type Output = A;
    type Repr = &'static [isize];
    fn into_repr(self) -> Option<&'static [isize]> {
        None
    }
}

// Iterate over single axis.
impl<A, D: Dimension> FoldAxes<A, D> for isize {
    type Output = Array<A, D::Smaller>;
    type Repr = [isize; 1];
    fn into_repr(self) -> Option<[isize; 1]> {
        Some([self])
    }
}

// impl<A, D: Dimension> FoldAxes<A, D> for (isize,) {...}

impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize) {
    type Output = Array<A, <<D as Dimension>::Smaller as Dimension>::Smaller>;
    type Repr = [isize; 2];
    fn into_repr(self) -> Option<[isize; 2]> {
        Some([self.0, self.1])
    }
}

// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize, isize, isize) {...}

// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 1] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 2] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 3] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 4] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 5] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 6] {...}

impl<'a, A, D: Dimension> FoldAxes<A, D> for &'a [isize] {
    type Output = Array<A, IxDyn>;
    type Repr = &'a [isize];
    fn into_repr(self) -> Option<&'a [isize]> {
        Some(self)
    }
}

impl<A, D: Dimension> FoldAxes<A, D> for Vec<isize> {
    type Output = Array<A, IxDyn>;
    type Repr = Vec<isize>;
    fn into_repr(self) -> Option<Vec<isize>> {
        Some(self)
    }
}

// same for `usize`, tuples of `usize`, slices of `usize`, vecs of `usize` (casting `usize` to `isize`)

Then, for example, sum would be

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn sum<T>(&self, axes: T) -> <T as FoldAxes<A, D>>::Output
    where
        A: Clone + Add<A, Output=A>,
        T: FoldAxes<A, D>,
    {
        // (or check zero-length here if not using the Option approach)
        let axes_repr: Option<T::Repr> = axes.into_repr();
        if let Some(repr) = &axes_repr {
            let axes_slice: &[isize] = repr.as_ref();
            // operate over specified axes...
        } else {
            // operate over all axes...
        }
    }
}

which you could call like

// Sum over all axes.
arr.sum(());
// Sum over axis 1.
arr.sum(1);
// Sum over axes 2 and 3.
arr.sum((2, 3));
// Sum over variable number of axes.
arr.sum(axes_vec);
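
To show the core idea (the argument type selects the return type) in a form that compiles without ndarray, here is a heavily simplified sketch. `Matrix` and `SumAxes` are hypothetical stand-ins for `ArrayBase` and `FoldAxes`, and only the "all axes" and "single axis" cases are covered.

```rust
/// Toy row-major 2-D matrix standing in for `ArrayBase` (hypothetical).
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f64>,
}

/// Simplified stand-in for `FoldAxes`: the implementing type decides
/// what the reduction returns.
pub trait SumAxes {
    type Output;
    fn sum_of(self, m: &Matrix) -> Self::Output;
}

// `()` means "all axes", so the result is a scalar.
impl SumAxes for () {
    type Output = f64;
    fn sum_of(self, m: &Matrix) -> f64 {
        m.data.iter().sum()
    }
}

// A single axis: the result has one dimension fewer.
impl SumAxes for usize {
    type Output = Vec<f64>;
    fn sum_of(self, m: &Matrix) -> Vec<f64> {
        match self {
            // Axis 0 collapses the rows: one sum per column.
            0 => (0..m.cols)
                .map(|c| (0..m.rows).map(|r| m.data[r * m.cols + c]).sum::<f64>())
                .collect(),
            // Axis 1 collapses the columns: one sum per row.
            1 => (0..m.rows)
                .map(|r| m.data[r * m.cols..(r + 1) * m.cols].iter().sum::<f64>())
                .collect(),
            _ => panic!("axis out of bounds for a 2-D matrix"),
        }
    }
}

impl Matrix {
    /// The return type is chosen statically by the type of `axes`.
    pub fn sum<T: SumAxes>(&self, axes: T) -> T::Output {
        axes.sum_of(self)
    }
}
```

In this sketch `m.sum(())` returns an `f64` while `m.sum(1usize)` returns a `Vec<f64>`, with no runtime dispatch and no default arguments.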

Note that FoldAxes::Output specifies the entire output type instead of just the output dimension because for the scalar case it would be nice to return A instead of Array0<A>. All of mean/stdev/min/max/variance behave the same way. For the argmin/argmax case, the element type of the output scalar/array is usize instead of A, so you would change the return type slightly:

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn argmin<T>(&self, axes: T) -> <T as FoldAxes<usize, D>>::Output
    where
        for<'a> &'a A: PartialOrd<&'a A>,
        T: FoldAxes<usize, D>,
    {
        // ...
    }
}

You could reduce duplication of code in the method implementations by implementing them in terms of a generalized fold operation (like the sum above but taking an initial value and closure) followed by map_inplace. (map_inplace is necessary for mean/stdev but not for min/max/sum.)
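
As a rough sketch of that idea, the following uses a plain row-major slice instead of an ndarray array and folds over rows only. The helper names `fold_rows`, `sum_rows`, `max_rows`, and `mean_rows` are invented for this example; the in-place loop in `mean_rows` plays the role of map_inplace.

```rust
/// Fold each row of a row-major matrix (`cols` elements per row)
/// into one accumulator, starting every row from a clone of `init`.
pub fn fold_rows<B, F>(data: &[f64], cols: usize, init: B, mut fold: F) -> Vec<B>
where
    B: Clone,
    F: FnMut(B, &f64) -> B,
{
    assert!(cols > 0 && data.len() % cols == 0, "data does not fit the row length");
    data.chunks(cols)
        .map(|row| row.iter().fold(init.clone(), &mut fold))
        .collect()
}

/// Sum needs only the fold step.
pub fn sum_rows(data: &[f64], cols: usize) -> Vec<f64> {
    fold_rows(data, cols, 0.0_f64, |acc, &x| acc + x)
}

/// Max also needs only the fold step.
pub fn max_rows(data: &[f64], cols: usize) -> Vec<f64> {
    fold_rows(data, cols, f64::NEG_INFINITY, |acc, &x| acc.max(x))
}

/// Mean is the same fold as sum, followed by an in-place map.
pub fn mean_rows(data: &[f64], cols: usize) -> Vec<f64> {
    let mut out = sum_rows(data, cols);
    out.iter_mut().for_each(|s| *s /= cols as f64);
    out
}
```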

Edit:

For what it's worth, if all you want to support is sum() and sum_axis(axis: isize), I think that's fine too and simpler than the generic approach. I can see potential use cases for summing over multiple axes, though.

You could also have sum() (scalar sum) and sum_axes<T>(axes: T) where T: FoldAxes<D> (sum over axis/axes), which would be nice because FoldAxes<A, D> could then be simplified to FoldAxes<D> with the output dimension as an associated type. Now that I think about it some more, I like this approach better than combining sum() and sum_axes() together because the scalar case would be simplified to arr.sum() and the FoldAxes trait would be simpler.

Edit 2:

I remembered that computing the standard deviation requires multiple numerical accumulators. So, instead of fold_axes (which would allocate an array containing all the accumulators) followed by map (to map the accumulators to the final result), I'd suggest combining the fold and map like this:

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn fold_map_axes<'a, T, B, F, M, O>(&'a self, axes: T, init: B, fold: F, map: M) -> Array<O, <T as FoldAxes<D>>::OutDim>
    where
        A: 'a,
        T: FoldAxes<D>,
        F: FnMut(B, &'a A) -> B,
        M: FnMut(B) -> O,
    {
        // ...
    }
}
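
To make the multiple-accumulator point concrete, here is a std-only sketch that folds over the rows of a row-major slice and maps each accumulator to its final value immediately, so no intermediate array of accumulators is allocated. `fold_map_rows` and `std_rows` are hypothetical names, the accumulator is a Welford-style `(count, mean, m2)` tuple, and the result is the population standard deviation (dividing by n); all of these are assumptions of this example rather than part of the proposal.

```rust
/// Fold each row into an accumulator of type `B`, then map that
/// accumulator straight to the output value of type `O`.
pub fn fold_map_rows<B, O, F, M>(
    data: &[f64],
    cols: usize,
    init: B,
    mut fold: F,
    mut map: M,
) -> Vec<O>
where
    B: Clone,
    F: FnMut(B, &f64) -> B,
    M: FnMut(B) -> O,
{
    assert!(cols > 0 && data.len() % cols == 0, "data does not fit the row length");
    data.chunks(cols)
        .map(|row| map(row.iter().fold(init.clone(), &mut fold)))
        .collect()
}

/// Population standard deviation of every row.
pub fn std_rows(data: &[f64], cols: usize) -> Vec<f64> {
    fold_map_rows(
        data,
        cols,
        (0usize, 0.0_f64, 0.0_f64), // (count, running mean, sum of squared deviations)
        |(n, mean, m2), &x| {
            let n = n + 1;
            let delta = x - mean;
            let mean = mean + delta / n as f64;
            (n, mean, m2 + delta * (x - mean))
        },
        |(n, _mean, m2)| (m2 / n as f64).sqrt(),
    )
}
```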

@fzyzcjy

fzyzcjy commented Oct 22, 2021

Any updates? This is frequently used... Thanks!

@Tootallgavin

this is available in ndarray-stats

@failable

@Tootallgavin This does not seem to be able to get max elements along any axis.

@nilgoyette
Collaborator

You mean an (&self, axis: Axis) version of the ndarray-stats quantile methods? Right, this makes sense. However, I would argue that they should also be in ndarray-stats, not here.
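
For reference, the behavior being asked for (one argmax per lane along a chosen axis) could be sketched as follows for a row-major 2-D buffer. This is a std-only illustration; `argmax_axis` and its signature are assumptions of this example and not an existing ndarray or ndarray-stats method.

```rust
use std::cmp::Ordering;

/// Argmax along `axis` of a row-major `rows x cols` matrix.
/// Axis 0 yields one row index per column; axis 1 yields one column index per row.
/// Returns `None` if the lanes are empty or a NaN is encountered.
pub fn argmax_axis(data: &[f64], rows: usize, cols: usize, axis: usize) -> Option<Vec<usize>> {
    assert_eq!(data.len(), rows * cols, "shape does not match data length");
    assert!(axis < 2, "axis out of bounds for a 2-D matrix");
    let (lanes, lane_len) = if axis == 0 { (cols, rows) } else { (rows, cols) };
    if lane_len == 0 {
        return None;
    }
    // Element `k` of lane `lane` along the chosen axis.
    let at = |lane: usize, k: usize| {
        if axis == 0 {
            data[k * cols + lane]
        } else {
            data[lane * cols + k]
        }
    };
    let mut out = Vec::with_capacity(lanes);
    for lane in 0..lanes {
        let mut best = 0;
        for k in 0..lane_len {
            if at(lane, k).partial_cmp(&at(lane, best))? == Ordering::Greater {
                best = k;
            }
        }
        out.push(best);
    }
    Some(out)
}
```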
