Skip to content

Commit 8f3b2b9

Browse files
committed
Experimental divide & conquer in matrix multiply using rayon
1 parent 443062f commit 8f3b2b9

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ version = "0.3.16"
3535
optional = true
3636

3737
[dependencies]
38+
rayon = "0.3.1"
39+
3840
# Use via the `blas` crate feature!
3941
blas-sys = { version = "0.6", optional = true, default-features = false }
4042
openblas-provider = { version = "0.4", optional = true, default-features = false }

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ extern crate rustc_serialize as serialize;
7373
extern crate blas_sys;
7474

7575
extern crate matrixmultiply;
76+
extern crate rayon;
7677

7778
extern crate itertools;
7879
extern crate num as libnum;

src/linalg/impl_linalg.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
use rayon;
910
use libnum::Zero;
1011
use itertools::free::enumerate;
1112

@@ -417,6 +418,8 @@ fn mat_mul_impl<A>(alpha: A,
417418
mat_mul_general(alpha, lhs, rhs, beta, c)
418419
}
419420

421+
const SPLIT: usize = 64;
422+
420423
/// C ← α A B + β C
421424
fn mat_mul_general<A>(alpha: A,
422425
lhs: &ArrayView<A, (Ix, Ix)>,
@@ -425,7 +428,27 @@ fn mat_mul_general<A>(alpha: A,
425428
c: &mut ArrayViewMut<A, (Ix, Ix)>)
426429
where A: LinalgScalar,
427430
{
428-
let ((m, k), (_, n)) = (lhs.dim, rhs.dim);
431+
let ((m, k), (k2, n)) = (lhs.dim, rhs.dim);
432+
433+
debug_assert_eq!(k, k2);
434+
if m > SPLIT {
435+
// [ A0 ] B = [ C0 ]
436+
// [ A1 ] [ C1 ]
437+
let mid = m / 2;
438+
let (a0, a1) = lhs.split_at(Axis(0), mid);
439+
let (mut c0, mut c1) = c.view_mut().split_at(Axis(0), mid);
440+
rayon::join(move || mat_mul_general(alpha, &a0, rhs, beta, &mut c0),
441+
move || mat_mul_general(alpha, &a1, rhs, beta, &mut c1));
442+
return;
443+
} else if n > SPLIT {
444+
// A [ B0 B1 ] = [ C0 C1 ]
445+
let mid = n / 2;
446+
let (b0, b1) = rhs.split_at(Axis(1), mid);
447+
let (mut c0, mut c1) = c.view_mut().split_at(Axis(1), mid);
448+
rayon::join(move || mat_mul_general(alpha, lhs, &b0, beta, &mut c0),
449+
move || mat_mul_general(alpha, lhs, &b1, beta, &mut c1));
450+
return;
451+
}
429452

430453
// common parameters for gemm
431454
let ap = lhs.as_ptr();

0 commit comments

Comments
 (0)