6
6
// option. This file may not be copied, modified, or distributed
7
7
// except according to those terms.
8
8
9
+ use rayon;
9
10
use libnum:: Zero ;
10
11
use itertools:: free:: enumerate;
11
12
@@ -417,6 +418,8 @@ fn mat_mul_impl<A>(alpha: A,
417
418
mat_mul_general ( alpha, lhs, rhs, beta, c)
418
419
}
419
420
421
// Size threshold for the recursive splitting in `mat_mul_general`.
// That function compares the first dimension of `lhs` (m) and, failing
// that, the second dimension of `rhs` (n) against this value; when the
// dimension is strictly greater, the operands and `c` are split in half
// along that axis and the two halves are handed to `rayon::join`.
// Dimensions of 64 or less fall through to the non-split gemm path.
+ const SPLIT : usize = 64 ;
422
+
420
423
/// C ← α A B + β C
421
424
fn mat_mul_general < A > ( alpha : A ,
422
425
lhs : & ArrayView < A , ( Ix , Ix ) > ,
@@ -425,7 +428,27 @@ fn mat_mul_general<A>(alpha: A,
425
428
c : & mut ArrayViewMut < A , ( Ix , Ix ) > )
426
429
where A : LinalgScalar ,
427
430
{
428
- let ( ( m, k) , ( _, n) ) = ( lhs. dim , rhs. dim ) ;
431
+ let ( ( m, k) , ( k2, n) ) = ( lhs. dim , rhs. dim ) ;
432
+
433
+ debug_assert_eq ! ( k, k2) ;
434
+ if m > SPLIT {
435
+ // [ A0 ] B = [ C0 ]
436
+ // [ A1 ] [ C1 ]
437
+ let mid = m / 2 ;
438
+ let ( a0, a1) = lhs. split_at ( Axis ( 0 ) , mid) ;
439
+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 0 ) , mid) ;
440
+ rayon:: join ( move || mat_mul_general ( alpha, & a0, rhs, beta, & mut c0) ,
441
+ move || mat_mul_general ( alpha, & a1, rhs, beta, & mut c1) ) ;
442
+ return ;
443
+ } else if n > SPLIT {
444
+ // A [ B0 B1 ] = [ C0 C1 ]
445
+ let mid = n / 2 ;
446
+ let ( b0, b1) = rhs. split_at ( Axis ( 1 ) , mid) ;
447
+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 1 ) , mid) ;
448
+ rayon:: join ( move || mat_mul_general ( alpha, lhs, & b0, beta, & mut c0) ,
449
+ move || mat_mul_general ( alpha, lhs, & b1, beta, & mut c1) ) ;
450
+ return ;
451
+ }
429
452
430
453
// common parameters for gemm
431
454
let ap = lhs. as_ptr ( ) ;
0 commit comments