[mlir][Vector] Move vector.mask
canonicalization to folder (#140324)
This MR moves the canonicalization that elides empty `vector.mask` ops to folders.
This commit is contained in:
parent
12c62ebcb2
commit
d6f394e141
@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [
|
||||
Location loc);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
|
@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Folds vector.mask ops with an all-true mask.
|
||||
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
MaskFormat maskFormat = getMaskFormat(getMask());
|
||||
if (isEmpty())
|
||||
/// Folds empty `vector.mask` with no passthru operand and with or without
|
||||
/// return values. For example:
|
||||
///
|
||||
/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
|
||||
/// vector<8xi1> -> vector<8xf32>
|
||||
/// %1 = user_op %0 : vector<8xf32>
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// %0 = user_op %a : vector<8xf32>
|
||||
///
|
||||
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
if (!maskOp.isEmpty() || maskOp.hasPassthru())
|
||||
return failure();
|
||||
|
||||
Block *block = maskOp.getMaskBlock();
|
||||
auto terminator = cast<vector::YieldOp>(block->front());
|
||||
if (terminator.getNumOperands() == 0) {
|
||||
// `vector.mask` has no results, just remove the `vector.mask`.
|
||||
return success();
|
||||
}
|
||||
|
||||
// `vector.mask` has results, propagate the results.
|
||||
llvm::append_range(results, terminator.getOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &results) {
|
||||
if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
|
||||
return success();
|
||||
|
||||
MaskFormat maskFormat = getMaskFormat(getMask());
|
||||
if (maskFormat != MaskFormat::AllTrue)
|
||||
return failure();
|
||||
|
||||
@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Elides empty vector.mask operations with or without return values. Propagates
|
||||
// the yielded values by the vector.yield terminator, if any, or erases the op,
|
||||
// otherwise.
|
||||
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(MaskOp maskOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
|
||||
if (maskingOp.getMaskableOp())
|
||||
return failure();
|
||||
|
||||
if (!maskOp.isEmpty())
|
||||
return failure();
|
||||
|
||||
Block *block = maskOp.getMaskBlock();
|
||||
auto terminator = cast<vector::YieldOp>(block->front());
|
||||
if (terminator.getNumOperands() == 0)
|
||||
rewriter.eraseOp(maskOp);
|
||||
else
|
||||
rewriter.replaceOp(maskOp, terminator.getOperands());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ElideEmptyMaskOp>(context);
|
||||
}
|
||||
|
||||
// MaskingOpInterface definitions.
|
||||
|
||||
/// Returns the operation masked by this 'vector.mask'.
|
||||
|
@ -1,10 +1,12 @@
|
||||
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: func @func
|
||||
// CHECK-SAME: %[[IN:.*]]: vector<11xf32>
|
||||
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
|
||||
%cst_41 = arith.constant dense<true> : vector<11xi1>
|
||||
// CHECK: vector.mask
|
||||
// CHECK-SAME: vector.yield %arg0
|
||||
// CHECK-NOT: vector.mask
|
||||
// CHECK: return %[[IN]] : vector<11xf32>
|
||||
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
|
||||
return %127 : vector<11xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user