[mlir][Vector] Move vector.mask
canonicalization to folder (#140324)
This MR moves the canonicalization that elides empty `vector.mask` ops to folders.
This commit is contained in:
parent
12c62ebcb2
commit
d6f394e141
@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [
|
|||||||
Location loc);
|
Location loc);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Folds vector.mask ops with an all-true mask.
|
/// Folds empty `vector.mask` with no passthru operand and with or without
|
||||||
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
/// return values. For example:
|
||||||
SmallVectorImpl<OpFoldResult> &results) {
|
///
|
||||||
MaskFormat maskFormat = getMaskFormat(getMask());
|
/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
|
||||||
if (isEmpty())
|
/// vector<8xi1> -> vector<8xf32>
|
||||||
|
/// %1 = user_op %0 : vector<8xf32>
|
||||||
|
///
|
||||||
|
/// becomes:
|
||||||
|
///
|
||||||
|
/// %0 = user_op %a : vector<8xf32>
|
||||||
|
///
|
||||||
|
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
|
||||||
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
|
if (!maskOp.isEmpty() || maskOp.hasPassthru())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Block *block = maskOp.getMaskBlock();
|
||||||
|
auto terminator = cast<vector::YieldOp>(block->front());
|
||||||
|
if (terminator.getNumOperands() == 0) {
|
||||||
|
// `vector.mask` has no results, just remove the `vector.mask`.
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// `vector.mask` has results, propagate the results.
|
||||||
|
llvm::append_range(results, terminator.getOperands());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
||||||
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
|
if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
MaskFormat maskFormat = getMaskFormat(getMask());
|
||||||
if (maskFormat != MaskFormat::AllTrue)
|
if (maskFormat != MaskFormat::AllTrue)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Elides empty vector.mask operations with or without return values. Propagates
|
|
||||||
// the yielded values by the vector.yield terminator, if any, or erases the op,
|
|
||||||
// otherwise.
|
|
||||||
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(MaskOp maskOp,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
|
|
||||||
if (maskingOp.getMaskableOp())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!maskOp.isEmpty())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Block *block = maskOp.getMaskBlock();
|
|
||||||
auto terminator = cast<vector::YieldOp>(block->front());
|
|
||||||
if (terminator.getNumOperands() == 0)
|
|
||||||
rewriter.eraseOp(maskOp);
|
|
||||||
else
|
|
||||||
rewriter.replaceOp(maskOp, terminator.getOperands());
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
||||||
MLIRContext *context) {
|
|
||||||
results.add<ElideEmptyMaskOp>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaskingOpInterface definitions.
|
// MaskingOpInterface definitions.
|
||||||
|
|
||||||
/// Returns the operation masked by this 'vector.mask'.
|
/// Returns the operation masked by this 'vector.mask'.
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
|
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
|
||||||
|
|
||||||
module {
|
module {
|
||||||
|
// CHECK-LABEL: func @func
|
||||||
|
// CHECK-SAME: %[[IN:.*]]: vector<11xf32>
|
||||||
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
|
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
|
||||||
%cst_41 = arith.constant dense<true> : vector<11xi1>
|
%cst_41 = arith.constant dense<true> : vector<11xi1>
|
||||||
// CHECK: vector.mask
|
// CHECK-NOT: vector.mask
|
||||||
// CHECK-SAME: vector.yield %arg0
|
// CHECK: return %[[IN]] : vector<11xf32>
|
||||||
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
|
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
|
||||||
return %127 : vector<11xf32>
|
return %127 : vector<11xf32>
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user