Skip to content

Commit 2fdfab4

Browse files
authored
Merge pull request #700 from adenchfi/master
Kronecker Product addition to stdlib_linalg
2 parents 91387a0 + a6584f7 commit 2fdfab4

File tree

6 files changed

+145
-1
lines changed

6 files changed

+145
-1
lines changed

doc/specs/stdlib_linalg.md

+31
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,37 @@ Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vect
160160
{!example/linalg/example_outer_product.f90!}
161161
```
162162

163+
## `kronecker_product` - Computes the Kronecker product of two rank-2 arrays
164+
165+
### Status
166+
167+
Experimental
168+
169+
### Description
170+
171+
Computes the Kronecker product of two rank-2 arrays
172+
173+
### Syntax
174+
175+
`C = [[stdlib_linalg(module):kronecker_product(interface)]](A, B)`
176+
177+
### Arguments
178+
179+
`A`: Shall be a rank-2 array with dimensions M1, N1
180+
181+
`B`: Shall be a rank-2 array with dimensions M2, N2
182+
183+
### Return value
184+
185+
Returns a rank-2 array equal to `A \otimes B`. The shape of the returned array is `[M1*M2, N1*N2]`.
186+
187+
### Example
188+
189+
```fortran
190+
{!example/linalg/example_kronecker_product.f90!}
191+
```
192+
193+
163194
## `cross_product` - Computes the cross product of two vectors
164195

165196
### Status
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
program example_kronecker_product
2+
use stdlib_linalg, only: kronecker_product
3+
implicit none
4+
integer, parameter :: m1 = 1, n1 = 2, m2 = 2, n2 = 3
5+
integer :: i, j
6+
real :: A(m1, n1), B(m2,n2)
7+
real, allocatable :: C(:,:)
8+
9+
do j = 1, n1
10+
do i = 1, m1
11+
A(i,j) = i*j ! A = [1, 2]
12+
end do
13+
end do
14+
15+
do j = 1, n2
16+
do i = 1, m2 ! B = [ 1, 2, 3 ]
17+
B(i,j) = i*j ! [ 2, 4, 6 ]
18+
end do
19+
end do
20+
21+
C = kronecker_product(A, B)
22+
! C = [ a(1,1) * B(:,:) | a(1,2) * B(:,:) ]
23+
! or in other words,
24+
! C = [ 1.00 2.00 3.00 2.00 4.00 6.00 ]
25+
! [ 2.00 4.00 6.00 4.00 8.00 12.00 ]
26+
end program example_kronecker_product

src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(fppFiles
2222
stdlib_linalg.fypp
2323
stdlib_linalg_diag.fypp
2424
stdlib_linalg_outer_product.fypp
25+
stdlib_linalg_kronecker.fypp
2526
stdlib_linalg_cross_product.fypp
2627
stdlib_optval.fypp
2728
stdlib_selection.fypp

src/stdlib_linalg.fypp

+15
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module stdlib_linalg
1414
public :: eye
1515
public :: trace
1616
public :: outer_product
17+
public :: kronecker_product
1718
public :: cross_product
1819
public :: is_square
1920
public :: is_diagonal
@@ -93,6 +94,20 @@ module stdlib_linalg
9394
#:endfor
9495
end interface outer_product
9596

97+
interface kronecker_product
98+
!! version: experimental
99+
!!
100+
!! Computes the Kronecker product of two arrays of size M1xN1, and of M2xN2, returning an (M1*M2)x(N1*N2) array
101+
!! ([Specification](../page/specs/stdlib_linalg.html#
102+
!! kronecker_product-computes-the-kronecker-product-of-two-matrices))
103+
#:for k1, t1 in RCI_KINDS_TYPES
104+
pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C)
105+
${t1}$, intent(in) :: A(:,:), B(:,:)
106+
${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2))
107+
end function kronecker_product_${t1[0]}$${k1}$
108+
#:endfor
109+
end interface kronecker_product
110+
96111

97112
! Cross product (of two vectors)
98113
interface cross_product

src/stdlib_linalg_kronecker.fypp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#:include "common.fypp"
2+
#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES
3+
submodule (stdlib_linalg) stdlib_linalg_kronecker
4+
5+
implicit none
6+
7+
contains
8+
9+
#:for k1, t1 in RCI_KINDS_TYPES
10+
pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C)
11+
${t1}$, intent(in) :: A(:,:), B(:,:)
12+
${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2))
13+
integer :: m1, n1, maxM1, maxN1, maxM2, maxN2
14+
15+
maxM1 = size(A, dim=1)
16+
maxN1 = size(A, dim=2)
17+
maxM2 = size(B, dim=1)
18+
maxN2 = size(B, dim=2)
19+
20+
21+
do n1 = 1, maxN1
22+
do m1 = 1, maxM1
23+
! We use the Wikipedia convention for ordering of the matrix elements
24+
! https://en.wikipedia.org/wiki/Kronecker_product
25+
C((m1-1)*maxM2+1:m1*maxM2, (n1-1)*maxN2+1:n1*maxN2) = A(m1, n1) * B(:,:)
26+
end do
27+
end do
28+
end function kronecker_product_${t1[0]}$${k1}$
29+
#:endfor
30+
end submodule stdlib_linalg_kronecker

test/linalg/test_linalg.fypp

+42-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#:include "common.fypp"
2+
#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES
23

34
module test_linalg
45
use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
56
use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
6-
use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product
7+
use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product, kronecker_product
78

89
implicit none
910

@@ -48,6 +49,9 @@ contains
4849
new_unittest("trace_int16", test_trace_int16), &
4950
new_unittest("trace_int32", test_trace_int32), &
5051
new_unittest("trace_int64", test_trace_int64), &
52+
#:for k1, t1 in RCI_KINDS_TYPES
53+
new_unittest("kronecker_product_${t1[0]}$${k1}$", test_kronecker_product_${t1[0]}$${k1}$), &
54+
#:endfor
5155
new_unittest("outer_product_rsp", test_outer_product_rsp), &
5256
new_unittest("outer_product_rdp", test_outer_product_rdp), &
5357
new_unittest("outer_product_rqp", test_outer_product_rqp), &
@@ -554,6 +558,43 @@ contains
554558
end subroutine test_trace_int64
555559

556560

561+
562+
#:for k1, t1 in RCI_KINDS_TYPES
563+
subroutine test_kronecker_product_${t1[0]}$${k1}$(error)
564+
!> Error handling
565+
type(error_type), allocatable, intent(out) :: error
566+
integer, parameter :: m1 = 1, n1 = 2, m2 = 2, n2 = 3
567+
${t1}$, dimension(m1*m2,n1*n2), parameter :: expected &
568+
= transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4,8,12], [m2*n2, m1*n1]))
569+
${t1}$, parameter :: tol = 1.e-6
570+
571+
${t1}$ :: A(m1,n1), B(m2,n2)
572+
${t1}$ :: C(m1*m2,n1*n2), diff(m1*m2,n1*n2)
573+
574+
integer :: i,j
575+
576+
do j = 1, n1
577+
do i = 1, m1
578+
A(i,j) = i*j ! A = [1, 2]
579+
end do
580+
end do
581+
582+
do j = 1, n2
583+
do i = 1, m2
584+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
585+
end do
586+
end do
587+
588+
C = kronecker_product(A,B)
589+
590+
diff = C - expected
591+
592+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
593+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
594+
595+
end subroutine test_kronecker_product_${t1[0]}$${k1}$
596+
#:endfor
597+
557598
subroutine test_outer_product_rsp(error)
558599
!> Error handling
559600
type(error_type), allocatable, intent(out) :: error

0 commit comments

Comments
 (0)