
Commit 755e776

[mlir][linalg] Vectorize 1D convolution
Differential Revision: https://reviews.llvm.org/D140188
1 parent 1436a92 commit 755e776

File tree

2 files changed: +491, -45 lines changed


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 212 additions & 45 deletions
@@ -385,8 +385,10 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
           [&](auto op) { return CombiningKind::ADD; })
       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
       .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
       .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
       .Case<arith::MulIOp, arith::MulFOp>(
           [&](auto op) { return CombiningKind::MUL; })
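
Note on the two added cases: signed and unsigned integer min/max disagree on the same bit pattern, which is why CombiningKind keeps distinct MAXSI/MAXUI (and MINSI/MINUI) entries. A standalone C++ sketch of the distinction (mine, not part of this patch):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // 0xFFFFFFFF is -1 when read as int32_t, but 4294967295 as uint32_t.
  uint32_t bits = 0xFFFFFFFFu;
  int32_t same = static_cast<int32_t>(bits);
  assert(std::max<int32_t>(same, 0) == 0);      // signed max: -1 < 0
  assert(std::max<uint32_t>(bits, 0u) == bits); // unsigned max: 2^32 - 1 wins
  return 0;
}
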
@@ -1796,6 +1798,26 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         op->getOperand(0).isa<BlockArgument>();
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+
 /// Generate a vector implementation for either:
 /// ```
 /// Op def: ( n, w, c, kw, f )
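
Every kind whitelisted by isSupportedPoolKind is an associative, commutative reduction, which is what lets the generator below fold a pooling window one kernel tap at a time. A scalar model of that property (illustrative only, not code from the patch):

#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// One pooling window reduced tap by tap; `combine` plays the role of the
// vector::CombiningKind accepted by isSupportedPoolKind.
int poolWindowRef(const std::vector<int> &taps,
                  const std::function<int(int, int)> &combine) {
  int acc = taps.front();
  for (size_t k = 1; k < taps.size(); ++k)
    acc = combine(acc, taps[k]);
  return acc;
}

int main() {
  const std::vector<int> taps = {3, 9, 4};
  assert(poolWindowRef(taps, [](int a, int b) { return std::max(a, b); }) == 9);
  assert(poolWindowRef(taps, [](int a, int b) { return a + b; }) == 16);
  return 0;
}
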
@@ -1838,41 +1860,33 @@ struct Conv1DGenerator
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    if (lhsShapedType.getRank() != 3 ||
-        (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
-        resShapedType.getRank() != 3)
+    // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC.
+    if (lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3)
       return;
 
-    // Check for reduction `add` preceded by `mul`.
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
     if (!reduceOp)
       return;
-    std::optional<vector::CombiningKind> maybeKind;
-    maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
+    redOp = reduceOp->getName().getIdentifier();
+
+    if (!setOperKind(reduceOp))
       return;
-    // Check for single `mul` predecessor. The `mul` operands must be block
-    // arguments or extension of block arguments.
-    Operation *mulOp = nullptr;
-    for (Value operand : reduceOp->getOperands()) {
-      if (operand.isa<BlockArgument>())
-        continue;
-      if (mulOp)
-        return;
-      mulOp = operand.getDefiningOp();
-      if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
-        return;
-    }
-    if (!mulOp)
+    auto maybeKind = getCombinerOpKind(reduceOp);
+    if (!(maybeKind && (*maybeKind == vector::CombiningKind::ADD ||
+                        (oper == Pool && isSupportedPoolKind(*maybeKind))))) {
       return;
-    for (Value operand : mulOp->getOperands()) {
-      if (Operation *def = operand.getDefiningOp()) {
-        if (!isa<CastOpInterface>(def))
-          return;
-        operand = def->getOperand(0);
-      }
-      if (!operand.isa<BlockArgument>())
+    }
+
+    auto rhsRank = rhsShapedType.getRank();
+    switch (oper) {
+    case Conv:
+      if (rhsRank != 2 && rhsRank != 3)
+        return;
+      break;
+    case Pool:
+      if (rhsRank != 1)
         return;
+      break;
     }
     // The op is now known to be valid.
     valid = true;
@@ -1889,38 +1903,70 @@ struct Conv1DGenerator
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
     if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv");
+      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::Nwc:
-      // kernel{kw, c, f}
-      bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
       // out{n, w, f}
-      bindShapeDims(resShapedType, nSize, wSize);
+      bindShapeDims(resShapedType, nSize, wSize, fSize);
+      switch (oper) {
+      case Conv:
+        // kernel{kw, c, f}
+        bindShapeDims(rhsShapedType, kwSize, cSize);
+        break;
+      case Pool:
+        // kernel{kw}
+        bindShapeDims(rhsShapedType, kwSize);
+        cSize = fSize;
+        break;
+      }
       lhsShape = {nSize,
                   // iw = ow * sw + kw * dw - 1
                   // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
                   // Perform the proper inclusive -> exclusive -> inclusive.
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                       1,
                   cSize};
-      rhsShape = {kwSize, cSize, fSize};
+      switch (oper) {
+      case Conv:
+        rhsShape = {kwSize, cSize, fSize};
+        break;
+      case Pool:
+        rhsShape = {kwSize};
+        break;
+      }
       resShape = {nSize, wSize, fSize};
       break;
     case Conv1DOpOrder::Ncw:
-      // kernel{f, c, kw}
-      bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
       // out{n, f, w}
       bindShapeDims(resShapedType, nSize, fSize, wSize);
+      switch (oper) {
+      case Conv:
+        // kernel{f, c, kw}
+        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
+        break;
+      case Pool:
+        // kernel{kw}
+        bindShapeDims(rhsShapedType, kwSize);
+        cSize = fSize;
+        break;
+      }
       lhsShape = {nSize, cSize,
                   // iw = ow * sw + kw * dw - 1
                   // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
                   // Perform the proper inclusive -> exclusive -> inclusive.
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                       1};
-      rhsShape = {fSize, cSize, kwSize};
+      switch (oper) {
+      case Conv:
+        rhsShape = {fSize, cSize, kwSize};
+        break;
+      case Pool:
+        rhsShape = {kwSize};
+        break;
+      }
       resShape = {nSize, fSize, wSize};
       break;
     }
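
The lhsShape arithmetic above is the standard "required input width" computation. A small standalone check of the inclusive -> exclusive -> inclusive comment (my sketch, not part of the patch):

#include <cassert>

// Input width needed so that `ow` output columns can be produced by a kernel
// of width `kw` at the given stride and dilation, written exactly as in the
// patch: the sum of two inclusive spans minus their shared endpoint.
int requiredInputWidth(int ow, int kw, int stride, int dilation) {
  return ((ow - 1) * stride + 1) + ((kw - 1) * dilation + 1) - 1;
}

int main() {
  // The comment's example: width 16 convolved with 3 (stride 1, dilation 1)
  // yields 14 outputs, so 14 outputs require an input of width 16.
  assert(requiredInputWidth(14, 3, 1, 1) == 16);
  // Dilation spreads the taps: kw = 3 at dilation 2 covers a span of 5.
  assert(requiredInputWidth(14, 3, 1, 2) == 18);
  return 0;
}
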
@@ -1944,8 +1990,11 @@ struct Conv1DGenerator
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(
-        loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    // This is needed only for Conv.
+    Value rhs = nullptr;
+    if (oper == Conv)
+      rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
@@ -1964,7 +2013,10 @@ struct Conv1DGenerator
     lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
     // fcw -> wcf
     static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
+
+    // This is needed only for Conv.
+    if (oper == Conv)
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
     // nfw -> nwf
     static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
     res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
@@ -1988,10 +2040,12 @@ struct Conv1DGenerator
       }
     }
     // Extract rhs slice of size {c, f} @ [kw].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-    }
+    // Do not do this for pooling.
+    if (oper == Conv)
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+            loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+      }
     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
@@ -2005,11 +2059,21 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
+    // or perform a simple arith operation for pooling.
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = conv1dSliceAsContraction(
-            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+        switch (oper) {
+        case Conv:
+          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                lhsVals[linearIndex(kw, w)],
+                                                rhsVals[kw], resVals[w]);
+          break;
+        case Pool:
+          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
+                                   resVals[w]);
+          break;
+        }
       }
     }
 
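To see what the kw/w double loop computes, here is a scalar NWC reference of the same decomposition (my own sketch, not the patch's code): each kernel tap kw contributes a strided slice of the input, folded into the running result by multiply-accumulate (Conv) or by the bare reduction (Pool):

#include <algorithm>
#include <cassert>
#include <vector>

// Scalar reference for one (n, channel) slice, mirroring the kw-outer loop
// above: conv accumulates in[sw * w + dw * kw] * filter[kw]; max-pooling
// combines in[sw * w + dw * kw] directly into res[w].
void conv1dRef(const std::vector<float> &in, const std::vector<float> &filt,
               std::vector<float> &res, int sw, int dw) {
  for (size_t kw = 0; kw < filt.size(); ++kw)
    for (size_t w = 0; w < res.size(); ++w)
      res[w] += in[sw * w + dw * kw] * filt[kw];
}

void maxPool1dRef(const std::vector<float> &in, int kwSize,
                  std::vector<float> &res, int sw, int dw) {
  for (int kw = 0; kw < kwSize; ++kw)
    for (size_t w = 0; w < res.size(); ++w)
      res[w] = std::max(res[w], in[sw * w + dw * kw]);
}

int main() {
  std::vector<float> in = {1, 2, 3, 4}, filt = {1, 1, 1};
  std::vector<float> conv(2, 0.0f), pool(2, -1e30f);
  conv1dRef(in, filt, conv, /*sw=*/1, /*dw=*/1);   // {1+2+3, 2+3+4}
  maxPool1dRef(in, 3, pool, /*sw=*/1, /*dw=*/1);   // {3, 4}
  assert(conv[0] == 6 && conv[1] == 9);
  assert(pool[0] == 3 && pool[1] == 4);
  return 0;
}
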
@@ -2060,6 +2124,16 @@ struct Conv1DGenerator
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
   }
 
+  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
+  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
+                    Value res) {
+    if (isPoolExt)
+      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
+    return rewriter
+        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
+        ->getResult(0);
+  }
+
   /// Generate a vector implementation for:
   /// ```
   /// Op def: ( n, w, c, kw)
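
pool1dSlice optionally applies the matched extension op before combining. A scalar analogue of that widen-then-reduce step (illustrative only; the widening cast stands in for poolExtOp, the addition for redOp):

#include <cassert>
#include <cstdint>
#include <vector>

// Scalar analogue of pool1dSlice for sum-pooling with widening: when the
// linalg region extends the input element type before reducing (isPoolExt),
// the same cast must be applied to each lhs slice before the combine.
int32_t sumPoolWidenRef(const std::vector<int8_t> &taps) {
  int32_t res = 0;
  for (int8_t t : taps) {
    int32_t widened = static_cast<int32_t>(t); // stands in for poolExtOp
    res = res + widened;                       // stands in for redOp
  }
  return res;
}

int main() {
  // 100 + 100 overflows int8_t but not the widened accumulator.
  const std::vector<int8_t> taps = {100, 100};
  assert(sumPoolWidenRef(taps) == 200);
  return 0;
}
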
@@ -2236,6 +2310,7 @@ struct Conv1DGenerator
                 /*rhsIndex*/ {kw, c, f},
                 /*resIndex*/ {n, w, f}}))
       return conv(Conv1DOpOrder::Nwc);
+
     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
   }
 
@@ -2256,6 +2331,41 @@ struct Conv1DGenerator
     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
   }
 
+  /// Entry point that transposes into the common form:
+  /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
+  FailureOr<Operation *> generateNwcPooling() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, w, c, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match pooling 3-par 1-red");
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {n, w, c}}))
+      return conv(Conv1DOpOrder::Nwc);
+
+    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
+  }
+
+  /// Entry point that transposes into the common form:
+  /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
+  FailureOr<Operation *> generateNcwPooling() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, c, w, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match pooling 3-par 1-red");
+
+    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {n, c, w}}))
+      return conv(Conv1DOpOrder::Ncw);
+
+    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
+  }
+
   /// Entry point that transposes into the common form:
   /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv() {
@@ -2275,10 +2385,61 @@ struct Conv1DGenerator
   }
 
 private:
+  enum OperKind { Conv, Pool };
   bool valid = false;
+  OperKind oper = Conv;
+  StringAttr redOp;
+  StringAttr poolExtOp;
+  bool isPoolExt = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
+
+  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
+  // Returns true iff it is a valid conv/pooling op.
+  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
+  // reduction + yield) and rhs is not used, then it is the body of a pooling.
+  // If conv, check for a single `mul` predecessor. The `mul` operands must be
+  // block arguments or extensions of block arguments.
+  // Otherwise, check for one or zero `ext` predecessors. The `ext` operands
+  // must be block arguments or extensions of block arguments.
+  bool setOperKind(Operation *reduceOp) {
+    int numBlockArguments =
+        llvm::count_if(reduceOp->getOperands(),
+                       [](Value v) { return v.isa<BlockArgument>(); });
+    switch (numBlockArguments) {
+    case 1: {
+      // Will be convolution if the feeder is a MulOp.
+      // Otherwise, check whether it can be pooling.
+      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
+        return !v.isa<BlockArgument>();
+      });
+      Operation *feedOp = (*feedValIt).getDefiningOp();
+      if (isCastOfBlockArgument(feedOp)) {
+        oper = Pool;
+        isPoolExt = true;
+        poolExtOp = feedOp->getName().getIdentifier();
+      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+                   llvm::all_of(feedOp->getOperands(), [](Value v) {
+                     if (v.isa<BlockArgument>())
+                       return true;
+                     if (Operation *op = v.getDefiningOp())
+                       return isCastOfBlockArgument(op);
+                     return false;
+                   }))) {
+        return false;
+      }
+      return true;
+    }
+    case 2:
+      // Must be pooling.
+      oper = Pool;
+      isPoolExt = false;
+      return true;
+    default:
+      return false;
+    }
+  }
 };
 } // namespace
 
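The shape of the region body is what setOperKind distinguishes: reduce(mul(arg0, arg1), acc) is a convolution, reduce(arg0, acc) a pooling, and reduce(ext(arg0), acc) a pooling with extension. A toy classifier over a hand-rolled tag (purely illustrative; `Feeder` is my stand-in, not MLIR's API):

#include <cassert>

// Toy model of the region-body patterns setOperKind distinguishes. One
// reduction operand is always the accumulator block argument; `feeder`
// describes what produces the other operand.
enum class Feeder { BlockArg, Mul, Ext };
enum class Kind { Conv, Pool };

Kind classify(Feeder feeder) {
  switch (feeder) {
  case Feeder::BlockArg:
    return Kind::Pool; // reduce(arg, acc): 2-op region (reduction + yield)
  case Feeder::Ext:
    return Kind::Pool; // reduce(ext(arg), acc): 3-op region, isPoolExt = true
  case Feeder::Mul:
    return Kind::Conv; // reduce(mul(...), acc): the convolution body
  }
  return Kind::Conv; // unreachable; keeps compilers happy
}

int main() {
  assert(classify(Feeder::Mul) == Kind::Conv);
  assert(classify(Feeder::BlockArg) == Kind::Pool);
  assert(classify(Feeder::Ext) == Kind::Pool);
  return 0;
}
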
@@ -2299,6 +2460,12 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   if (succeeded(res))
     return res;
   res = e.generateNcwConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNwcPooling();
+  if (succeeded(res))
+    return res;
+  res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();
