@@ -385,8 +385,10 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
385
385
[&](auto op) { return CombiningKind::ADD; })
386
386
.Case <arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
387
387
.Case <arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
388
+ .Case <arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
388
389
.Case <arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
389
390
.Case <arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
391
+ .Case <arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
390
392
.Case <arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
391
393
.Case <arith::MulIOp, arith::MulFOp>(
392
394
[&](auto op) { return CombiningKind::MUL; })
@@ -1796,6 +1798,26 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
1796
1798
}
1797
1799
1798
1800
namespace {
1801
+ bool isCastOfBlockArgument (Operation *op) {
1802
+ return isa<CastOpInterface>(op) && op->getNumOperands () == 1 &&
1803
+ op->getOperand (0 ).isa <BlockArgument>();
1804
+ }
1805
+
1806
+ bool isSupportedPoolKind (vector::CombiningKind kind) {
1807
+ switch (kind) {
1808
+ case vector::CombiningKind::ADD:
1809
+ case vector::CombiningKind::MAXF:
1810
+ case vector::CombiningKind::MAXSI:
1811
+ case vector::CombiningKind::MAXUI:
1812
+ case vector::CombiningKind::MINF:
1813
+ case vector::CombiningKind::MINSI:
1814
+ case vector::CombiningKind::MINUI:
1815
+ return true ;
1816
+ default :
1817
+ return false ;
1818
+ }
1819
+ }
1820
+
1799
1821
// / Generate a vector implementation for either:
1800
1822
// / ```
1801
1823
// / Op def: ( n, w, c, kw, f )
@@ -1838,41 +1860,33 @@ struct Conv1DGenerator
1838
1860
resShapedType = resShaped.getType ().dyn_cast <ShapedType>();
1839
1861
if (!lhsShapedType || !rhsShapedType || !resShapedType)
1840
1862
return ;
1841
- if (lhsShapedType.getRank () != 3 ||
1842
- (rhsShapedType.getRank () != 2 && rhsShapedType.getRank () != 3 ) ||
1843
- resShapedType.getRank () != 3 )
1863
+ // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC.
1864
+ if (lhsShapedType.getRank () != 3 || resShapedType.getRank () != 3 )
1844
1865
return ;
1845
1866
1846
- // Check for reduction `add` preceded by `mul`.
1847
1867
Operation *reduceOp = matchLinalgReduction (linalgOp.getDpsInitOperand (0 ));
1848
1868
if (!reduceOp)
1849
1869
return ;
1850
- std::optional<vector::CombiningKind> maybeKind ;
1851
- maybeKind = getCombinerOpKind (reduceOp);
1852
- if (!maybeKind || *maybeKind != vector::CombiningKind::ADD )
1870
+ redOp = reduceOp-> getName (). getIdentifier () ;
1871
+
1872
+ if (!setOperKind (reduceOp) )
1853
1873
return ;
1854
- // Check for single `mul` predecessor. The `mul` operands must be block
1855
- // arguments or extension of block arguments.
1856
- Operation *mulOp = nullptr ;
1857
- for (Value operand : reduceOp->getOperands ()) {
1858
- if (operand.isa <BlockArgument>())
1859
- continue ;
1860
- if (mulOp)
1861
- return ;
1862
- mulOp = operand.getDefiningOp ();
1863
- if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
1864
- return ;
1865
- }
1866
- if (!mulOp)
1874
+ auto maybeKind = getCombinerOpKind (reduceOp);
1875
+ if (!(maybeKind && (*maybeKind == vector::CombiningKind::ADD ||
1876
+ (oper == Pool && isSupportedPoolKind (*maybeKind))))) {
1867
1877
return ;
1868
- for (Value operand : mulOp->getOperands ()) {
1869
- if (Operation *def = operand.getDefiningOp ()) {
1870
- if (!isa<CastOpInterface>(def))
1871
- return ;
1872
- operand = def->getOperand (0 );
1873
- }
1874
- if (!operand.isa <BlockArgument>())
1878
+ }
1879
+
1880
+ auto rhsRank = rhsShapedType.getRank ();
1881
+ switch (oper) {
1882
+ case Conv:
1883
+ if (rhsRank != 2 && rhsRank!= 3 )
1884
+ return ;
1885
+ break ;
1886
+ case Pool:
1887
+ if (rhsRank != 1 )
1875
1888
return ;
1889
+ break ;
1876
1890
}
1877
1891
// The op is now known to be valid.
1878
1892
valid = true ;
@@ -1889,38 +1903,70 @@ struct Conv1DGenerator
1889
1903
// / > 1.
1890
1904
FailureOr<Operation *> conv (Conv1DOpOrder conv1DOpOrder) {
1891
1905
if (!valid)
1892
- return rewriter.notifyMatchFailure (op, " unvectorizable 1-D conv" );
1906
+ return rewriter.notifyMatchFailure (op, " unvectorizable 1-D conv/pool " );
1893
1907
1894
1908
int64_t nSize, wSize, cSize, kwSize, fSize ;
1895
1909
SmallVector<int64_t , 3 > lhsShape, rhsShape, resShape;
1896
1910
switch (conv1DOpOrder) {
1897
1911
case Conv1DOpOrder::Nwc:
1898
- // kernel{kw, c, f}
1899
- bindShapeDims (rhsShapedType, kwSize, cSize, fSize );
1900
1912
// out{n, w, f}
1901
- bindShapeDims (resShapedType, nSize, wSize);
1913
+ bindShapeDims (resShapedType, nSize, wSize, fSize );
1914
+ switch (oper) {
1915
+ case Conv:
1916
+ // kernel{kw, c, f}
1917
+ bindShapeDims (rhsShapedType, kwSize, cSize);
1918
+ break ;
1919
+ case Pool:
1920
+ // kernel{kw}
1921
+ bindShapeDims (rhsShapedType, kwSize);
1922
+ cSize = fSize ;
1923
+ break ;
1924
+ }
1902
1925
lhsShape = {nSize,
1903
1926
// iw = ow * sw + kw * dw - 1
1904
1927
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1905
1928
// Perform the proper inclusive -> exclusive -> inclusive.
1906
1929
((wSize - 1 ) * strideW + 1 ) + ((kwSize - 1 ) * dilationW + 1 ) -
1907
1930
1 ,
1908
1931
cSize};
1909
- rhsShape = {kwSize, cSize, fSize };
1932
+ switch (oper) {
1933
+ case Conv:
1934
+ rhsShape = {kwSize, cSize, fSize };
1935
+ break ;
1936
+ case Pool:
1937
+ rhsShape = {kwSize};
1938
+ break ;
1939
+ }
1910
1940
resShape = {nSize, wSize, fSize };
1911
1941
break ;
1912
1942
case Conv1DOpOrder::Ncw:
1913
- // kernel{f, c, kw}
1914
- bindShapeDims (rhsShapedType, fSize , cSize, kwSize);
1915
1943
// out{n, f, w}
1916
1944
bindShapeDims (resShapedType, nSize, fSize , wSize);
1945
+ switch (oper) {
1946
+ case Conv:
1947
+ // kernel{f, c, kw}
1948
+ bindShapeDims (rhsShapedType, fSize , cSize, kwSize);
1949
+ break ;
1950
+ case Pool:
1951
+ // kernel{kw}
1952
+ bindShapeDims (rhsShapedType, kwSize);
1953
+ cSize = fSize ;
1954
+ break ;
1955
+ }
1917
1956
lhsShape = {nSize, cSize,
1918
1957
// iw = ow * sw + kw * dw - 1
1919
1958
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1920
1959
// Perform the proper inclusive -> exclusive -> inclusive.
1921
1960
((wSize - 1 ) * strideW + 1 ) + ((kwSize - 1 ) * dilationW + 1 ) -
1922
1961
1 };
1923
- rhsShape = {fSize , cSize, kwSize};
1962
+ switch (oper) {
1963
+ case Conv:
1964
+ rhsShape = {fSize , cSize, kwSize};
1965
+ break ;
1966
+ case Pool:
1967
+ rhsShape = {kwSize};
1968
+ break ;
1969
+ }
1924
1970
resShape = {nSize, fSize , wSize};
1925
1971
break ;
1926
1972
}
@@ -1944,8 +1990,11 @@ struct Conv1DGenerator
1944
1990
Value lhs = rewriter.create <vector::TransferReadOp>(
1945
1991
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
1946
1992
// Read rhs slice of size {kw, c, f} @ [0, 0, 0].
1947
- Value rhs = rewriter.create <vector::TransferReadOp>(
1948
- loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
1993
+ // This is needed only for Conv.
1994
+ Value rhs = nullptr ;
1995
+ if (oper == Conv)
1996
+ rhs = rewriter.create <vector::TransferReadOp>(
1997
+ loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
1949
1998
// Read res slice of size {n, w, f} @ [0, 0, 0].
1950
1999
Value res = rewriter.create <vector::TransferReadOp>(
1951
2000
loc, resType, resShaped, ValueRange{zero, zero, zero});
@@ -1964,7 +2013,10 @@ struct Conv1DGenerator
1964
2013
lhs = rewriter.create <vector::TransposeOp>(loc, lhs, permLhs);
1965
2014
// fcw -> wcf
1966
2015
static constexpr std::array<int64_t , 3 > permRhs = {2 , 1 , 0 };
1967
- rhs = rewriter.create <vector::TransposeOp>(loc, rhs, permRhs);
2016
+
2017
+ // This is needed only for Conv.
2018
+ if (oper == Conv)
2019
+ rhs = rewriter.create <vector::TransposeOp>(loc, rhs, permRhs);
1968
2020
// nfw -> nwf
1969
2021
static constexpr std::array<int64_t , 3 > permRes = {0 , 2 , 1 };
1970
2022
res = rewriter.create <vector::TransposeOp>(loc, res, permRes);
@@ -1988,10 +2040,12 @@ struct Conv1DGenerator
1988
2040
}
1989
2041
}
1990
2042
// Extract rhs slice of size {c, f} @ [kw].
1991
- for (int64_t kw = 0 ; kw < kwSize; ++kw) {
1992
- rhsVals.push_back (rewriter.create <vector::ExtractOp>(
1993
- loc, rhs, /* offsets=*/ ArrayRef<int64_t >{kw}));
1994
- }
2043
+ // Do not do for pooling.
2044
+ if (oper == Conv)
2045
+ for (int64_t kw = 0 ; kw < kwSize; ++kw) {
2046
+ rhsVals.push_back (rewriter.create <vector::ExtractOp>(
2047
+ loc, rhs, /* offsets=*/ ArrayRef<int64_t >{kw}));
2048
+ }
1995
2049
// Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
1996
2050
for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
1997
2051
resVals.push_back (rewriter.create <vector::ExtractStridedSliceOp>(
@@ -2005,11 +2059,21 @@ struct Conv1DGenerator
2005
2059
return kw * (wSize / wSizeStep) + w;
2006
2060
};
2007
2061
2008
- // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
2062
+ // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
2063
+ // perform simple arith operation for pooling
2009
2064
for (int64_t kw = 0 ; kw < kwSize; ++kw) {
2010
2065
for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
2011
- resVals[w] = conv1dSliceAsContraction (
2012
- rewriter, loc, lhsVals[linearIndex (kw, w)], rhsVals[kw], resVals[w]);
2066
+ switch (oper) {
2067
+ case Conv:
2068
+ resVals[w] = conv1dSliceAsContraction (rewriter, loc,
2069
+ lhsVals[linearIndex (kw, w)],
2070
+ rhsVals[kw], resVals[w]);
2071
+ break ;
2072
+ case Pool:
2073
+ resVals[w] = pool1dSlice (rewriter, loc, lhsVals[linearIndex (kw, w)],
2074
+ resVals[w]);
2075
+ break ;
2076
+ }
2013
2077
}
2014
2078
}
2015
2079
@@ -2060,6 +2124,16 @@ struct Conv1DGenerator
2060
2124
/* iteratorTypes=*/ ArrayRef<vector::IteratorType>{par, par, par, red});
2061
2125
}
2062
2126
2127
+ // Create a reduction: lhs{n, w, c} -> res{n, w, c}
2128
+ Value pool1dSlice (RewriterBase &rewriter, Location loc, Value lhs,
2129
+ Value res) {
2130
+ if (isPoolExt)
2131
+ lhs = rewriter.create (loc, poolExtOp, lhs, res.getType ())->getResult (0 );
2132
+ return rewriter
2133
+ .create (loc, redOp, ArrayRef<Value>{lhs, res}, res.getType ())
2134
+ ->getResult (0 );
2135
+ }
2136
+
2063
2137
// / Generate a vector implementation for:
2064
2138
// / ```
2065
2139
// / Op def: ( n, w, c, kw)
@@ -2236,6 +2310,7 @@ struct Conv1DGenerator
2236
2310
/* rhsIndex*/ {kw, c, f},
2237
2311
/* resIndex*/ {n, w, f}}))
2238
2312
return conv (Conv1DOpOrder::Nwc);
2313
+
2239
2314
return rewriter.notifyMatchFailure (op, " not a conv::Nwc layout" );
2240
2315
}
2241
2316
@@ -2256,6 +2331,41 @@ struct Conv1DGenerator
2256
2331
return rewriter.notifyMatchFailure (op, " not a conv::Ncw layout" );
2257
2332
}
2258
2333
2334
+ // / Entry point that transposes into the common form:
2335
+ // / {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
2336
+ FailureOr<Operation *> generateNwcPooling () {
2337
+ AffineExpr n, w, c, kw;
2338
+ bindDims (ctx, n, w, c, kw);
2339
+ if (!iters ({Par (), Par (), Par (), Red ()}))
2340
+ return rewriter.notifyMatchFailure (op,
2341
+ " failed to match pooling 3-par 1-red" );
2342
+
2343
+ // No transposition needed.
2344
+ if (layout ({/* lhsIndex*/ {n, strideW * w + dilationW * kw, c},
2345
+ /* rhsIndex*/ {kw},
2346
+ /* resIndex*/ {n, w, c}}))
2347
+ return conv (Conv1DOpOrder::Nwc);
2348
+
2349
+ return rewriter.notifyMatchFailure (op, " not a pooling::Nwc layout" );
2350
+ }
2351
+
2352
+ // / Entry point that transposes into the common form:
2353
+ // / {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
2354
+ FailureOr<Operation *> generateNcwPooling () {
2355
+ AffineExpr n, w, c, kw;
2356
+ bindDims (ctx, n, c, w, kw);
2357
+ if (!iters ({Par (), Par (), Par (), Red ()}))
2358
+ return rewriter.notifyMatchFailure (op,
2359
+ " failed to match pooling 3-par 1-red" );
2360
+
2361
+ if (layout ({/* lhsIndex*/ {n, c, strideW * w + dilationW * kw},
2362
+ /* rhsIndex*/ {kw},
2363
+ /* resIndex*/ {n, c, w}}))
2364
+ return conv (Conv1DOpOrder::Ncw);
2365
+
2366
+ return rewriter.notifyMatchFailure (op, " not a pooling::Ncw layout" );
2367
+ }
2368
+
2259
2369
// / Entry point that transposes into the common form:
2260
2370
// / {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2261
2371
FailureOr<Operation *> generateDilatedConv () {
@@ -2275,10 +2385,61 @@ struct Conv1DGenerator
2275
2385
}
2276
2386
2277
2387
private:
2388
+ enum OperKind { Conv, Pool };
2278
2389
bool valid = false ;
2390
+ OperKind oper = Conv;
2391
+ StringAttr redOp;
2392
+ StringAttr poolExtOp;
2393
+ bool isPoolExt = false ;
2279
2394
int strideW, dilationW;
2280
2395
Value lhsShaped, rhsShaped, resShaped;
2281
2396
ShapedType lhsShapedType, rhsShapedType, resShapedType;
2397
+
2398
+ // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
2399
+ // Returns true iff it is a valid conv/pooling op.
2400
+ // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2401
+ // + yield) and rhs is not used) then it is the body of a pooling
2402
+ // If conv, check for single `mul` predecessor. The `mul` operands must be
2403
+ // block arguments or extension of block arguments.
2404
+ // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2405
+ // must be block arguments or extension of block arguments.
2406
+ bool setOperKind (Operation *reduceOp) {
2407
+ int numBlockArguments =
2408
+ llvm::count_if (reduceOp->getOperands (),
2409
+ [](Value v) { return v.isa <BlockArgument>(); });
2410
+ switch (numBlockArguments) {
2411
+ case 1 : {
2412
+ // Will be convolution if feeder is a MulOp.
2413
+ // Otherwise, if it can be pooling.
2414
+ auto feedValIt = llvm::find_if (reduceOp->getOperands (), [](Value v) {
2415
+ return !v.isa <BlockArgument>();
2416
+ });
2417
+ Operation *feedOp = (*feedValIt).getDefiningOp ();
2418
+ if (isCastOfBlockArgument (feedOp)) {
2419
+ oper = Pool;
2420
+ isPoolExt = true ;
2421
+ poolExtOp = feedOp->getName ().getIdentifier ();
2422
+ } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
2423
+ llvm::all_of (feedOp->getOperands (), [](Value v) {
2424
+ if (v.isa <BlockArgument>())
2425
+ return true ;
2426
+ if (Operation *op = v.getDefiningOp ())
2427
+ return isCastOfBlockArgument (op);
2428
+ return false ;
2429
+ }))) {
2430
+ return false ;
2431
+ }
2432
+ return true ;
2433
+ }
2434
+ case 2 :
2435
+ // Must be pooling
2436
+ oper = Pool;
2437
+ isPoolExt = false ;
2438
+ return true ;
2439
+ default :
2440
+ return false ;
2441
+ }
2442
+ }
2282
2443
};
2283
2444
} // namespace
2284
2445
@@ -2299,6 +2460,12 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
2299
2460
if (succeeded (res))
2300
2461
return res;
2301
2462
res = e.generateNcwConv ();
2463
+ if (succeeded (res))
2464
+ return res;
2465
+ res = e.generateNwcPooling ();
2466
+ if (succeeded (res))
2467
+ return res;
2468
+ res = e.generateNcwPooling ();
2302
2469
if (succeeded (res))
2303
2470
return res;
2304
2471
return e.generateDilatedConv ();
0 commit comments